In [25]:
import sys
import torch
from torch import nn
import time
from tqdm import tqdm
import torch.optim as optim
import copy

# Specify where to find the data preparation class
sys.path.append('../../Data_Preparation')
from Preparation import CustomDataLoader

In [2]:
# InceptionV3 training data (ImageNet) properties
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
DIMENSIONS = 3
#SIZE = 256

In [3]:
# Instantiate the CustomDataLoader class for training
train_data_loader = CustomDataLoader(data_path="../../FER2013_Data", batch_size=32, dataset_type="train", mean=MEAN, std=STD, dimensions=DIMENSIONS).data_loader
test_data_loader = CustomDataLoader(data_path="../../FER2013_Data", batch_size=32, dataset_type="test", mean=MEAN, std=STD, dimensions=DIMENSIONS).data_loader

# Confirm correct data load
print("Train Data Loader:")
for batch_idx, (inputs, labels) in enumerate(train_data_loader):
    print("Batch Index:", batch_idx)
    print("Inputs Shape:", inputs.shape)
    print("Labels Shape:", labels.shape)
    # Print the first few labels in the batch
    print("Labels:", labels[:5])
    # Break after printing a few batches
    if batch_idx == 2:
        break

Train Data Loader:
Batch Index: 0
Inputs Shape: torch.Size([32, 3, 299, 299])
Labels Shape: torch.Size([32])
Labels: tensor([2, 6, 0, 3, 2])
Batch Index: 1
Inputs Shape: torch.Size([32, 3, 299, 299])
Labels Shape: torch.Size([32])
Labels: tensor([6, 3, 2, 5, 6])
Batch Index: 2
Inputs Shape: torch.Size([32, 3, 299, 299])
Labels Shape: torch.Size([32])
Labels: tensor([6, 5, 6, 0, 5])


In [45]:
# load up the InceptionV3 model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_ft = models.inception_v3(pretrained=True)

model.aux_logits = False

for parameter in model.parameters():
    parameter.requires_grad = False

model.fc = nn.Sequential(
    nn.Linear(model_ft.fc.in_features, 10),
    nn.Linear(10, 7)
)

model = model.to(device)

loss = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad, model_ft.parameters()), lr=0.001)

In [47]:
num_epochs = 25

for epoch in tqdm(range(num_epochs)):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)
    
    total_batch = len(train_data_loader.dataset)//32

    for i, (batch_images, batch_labels) in enumerate(train_data_loader):
        
        X = batch_images.to(device)
        Y = batch_labels.to(device)

        pre = model(X)
        cost = loss(pre, Y)

        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        if (i+1) % 5 == 0:
            print('Epoch [%d/%d], lter [%d/%d] Loss: %.4f'
                 %(epoch+1, num_epochs, i+1, total_batch, cost.item()))

model.eval()

correct = 0
total = 0

for images, labels in test_data_loader:
    
    images = images.to(device)
    outputs = model(images)
    
    _, predicted = torch.max(outputs.data, 1)
    
    total += labels.size(0)
    correct += (predicted == labels.to(device)-1).sum()
    
print('Accuracy of test images: %f %%' % (100 * float(correct) / total))

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch 0/24
----------
Epoch [1/25], lter [5/876] Loss: 1.9459
Epoch [1/25], lter [10/876] Loss: 1.9082
Epoch [1/25], lter [15/876] Loss: 1.9744
Epoch [1/25], lter [20/876] Loss: 2.0286
Epoch [1/25], lter [25/876] Loss: 2.0308
Epoch [1/25], lter [30/876] Loss: 1.9939
Epoch [1/25], lter [35/876] Loss: 1.9843
Epoch [1/25], lter [40/876] Loss: 1.9548
Epoch [1/25], lter [45/876] Loss: 1.9837
Epoch [1/25], lter [50/876] Loss: 1.9853
Epoch [1/25], lter [55/876] Loss: 1.9757
Epoch [1/25], lter [60/876] Loss: 1.9699
Epoch [1/25], lter [65/876] Loss: 1.9789
Epoch [1/25], lter [70/876] Loss: 1.9876
Epoch [1/25], lter [75/876] Loss: 1.9915
Epoch [1/25], lter [80/876] Loss: 1.9935
Epoch [1/25], lter [85/876] Loss: 1.9770
Epoch [1/25], lter [90/876] Loss: 1.9529
Epoch [1/25], lter [95/876] Loss: 1.9444
Epoch [1/25], lter [100/876] Loss: 1.9724
Epoch [1/25], lter [105/876] Loss: 1.9823
Epoch [1/25], lter [110/876] Loss: 1.9896
Epoch [1/25], lter [115/876] Loss: 1.9575
Epoch [1/25], lter [120/876] Los

  4%|▍         | 1/25 [59:14<23:41:37, 3554.06s/it]

Epoch 1/24
----------
Epoch [2/25], lter [5/876] Loss: 1.9891
Epoch [2/25], lter [10/876] Loss: 1.9390
Epoch [2/25], lter [15/876] Loss: 1.9907
Epoch [2/25], lter [20/876] Loss: 1.9649
Epoch [2/25], lter [25/876] Loss: 1.9994
Epoch [2/25], lter [30/876] Loss: 1.9478
Epoch [2/25], lter [35/876] Loss: 1.9730
Epoch [2/25], lter [40/876] Loss: 1.9958
Epoch [2/25], lter [45/876] Loss: 2.0026
Epoch [2/25], lter [50/876] Loss: 1.9470
Epoch [2/25], lter [55/876] Loss: 1.9505
Epoch [2/25], lter [60/876] Loss: 1.9981
Epoch [2/25], lter [65/876] Loss: 1.9695
Epoch [2/25], lter [70/876] Loss: 1.9718
Epoch [2/25], lter [75/876] Loss: 1.9771
Epoch [2/25], lter [80/876] Loss: 2.0006
Epoch [2/25], lter [85/876] Loss: 1.9101
Epoch [2/25], lter [90/876] Loss: 1.9532
Epoch [2/25], lter [95/876] Loss: 1.8949
Epoch [2/25], lter [100/876] Loss: 1.9529
Epoch [2/25], lter [105/876] Loss: 1.9264
Epoch [2/25], lter [110/876] Loss: 1.9295
Epoch [2/25], lter [115/876] Loss: 1.9461
Epoch [2/25], lter [120/876] Los

  8%|▊         | 2/25 [1:53:27<21:34:37, 3377.30s/it]

Epoch 2/24
----------
Epoch [3/25], lter [5/876] Loss: 1.9869
Epoch [3/25], lter [10/876] Loss: 1.9499
Epoch [3/25], lter [15/876] Loss: 1.9677
Epoch [3/25], lter [20/876] Loss: 1.9467
Epoch [3/25], lter [25/876] Loss: 1.9533
Epoch [3/25], lter [30/876] Loss: 1.9929
Epoch [3/25], lter [35/876] Loss: 1.9815
Epoch [3/25], lter [40/876] Loss: 1.9571
Epoch [3/25], lter [45/876] Loss: 1.9711
Epoch [3/25], lter [50/876] Loss: 1.9841
Epoch [3/25], lter [55/876] Loss: 2.0008
Epoch [3/25], lter [60/876] Loss: 1.9522
Epoch [3/25], lter [65/876] Loss: 1.9780
Epoch [3/25], lter [70/876] Loss: 1.9465
Epoch [3/25], lter [75/876] Loss: 1.9505
Epoch [3/25], lter [80/876] Loss: 1.9550
Epoch [3/25], lter [85/876] Loss: 1.9660
Epoch [3/25], lter [90/876] Loss: 1.9762
Epoch [3/25], lter [95/876] Loss: 1.9762
Epoch [3/25], lter [100/876] Loss: 2.0602
Epoch [3/25], lter [105/876] Loss: 1.9649
Epoch [3/25], lter [110/876] Loss: 1.9497
Epoch [3/25], lter [115/876] Loss: 1.9760
Epoch [3/25], lter [120/876] Los

 12%|█▏        | 3/25 [2:52:35<21:06:55, 3455.26s/it]

Epoch 3/24
----------
Epoch [4/25], lter [5/876] Loss: 1.9459
Epoch [4/25], lter [10/876] Loss: 2.0244
Epoch [4/25], lter [15/876] Loss: 1.9868
Epoch [4/25], lter [20/876] Loss: 1.9737
Epoch [4/25], lter [25/876] Loss: 1.9386
Epoch [4/25], lter [30/876] Loss: 1.9847
Epoch [4/25], lter [35/876] Loss: 1.9377
Epoch [4/25], lter [40/876] Loss: 1.9825
Epoch [4/25], lter [45/876] Loss: 1.9758
Epoch [4/25], lter [50/876] Loss: 1.9301
Epoch [4/25], lter [55/876] Loss: 2.0133
Epoch [4/25], lter [60/876] Loss: 2.0012
Epoch [4/25], lter [65/876] Loss: 1.9808
Epoch [4/25], lter [70/876] Loss: 2.0257
Epoch [4/25], lter [75/876] Loss: 1.9760
Epoch [4/25], lter [80/876] Loss: 2.0154
Epoch [4/25], lter [85/876] Loss: 1.9693
Epoch [4/25], lter [90/876] Loss: 2.0091
Epoch [4/25], lter [95/876] Loss: 2.0241
Epoch [4/25], lter [100/876] Loss: 1.9669
Epoch [4/25], lter [105/876] Loss: 1.9809
Epoch [4/25], lter [110/876] Loss: 1.9228
Epoch [4/25], lter [115/876] Loss: 1.9298
Epoch [4/25], lter [120/876] Los

 16%|█▌        | 4/25 [3:47:15<19:45:07, 3386.08s/it]

Epoch 4/24
----------
Epoch [5/25], lter [5/876] Loss: 1.9778
Epoch [5/25], lter [10/876] Loss: 1.9491
Epoch [5/25], lter [15/876] Loss: 1.9613
Epoch [5/25], lter [20/876] Loss: 1.9548
Epoch [5/25], lter [25/876] Loss: 1.9462
Epoch [5/25], lter [30/876] Loss: 1.9550
Epoch [5/25], lter [35/876] Loss: 1.9700
Epoch [5/25], lter [40/876] Loss: 1.9164
Epoch [5/25], lter [45/876] Loss: 1.9524
Epoch [5/25], lter [50/876] Loss: 1.9777
Epoch [5/25], lter [55/876] Loss: 1.9866
Epoch [5/25], lter [60/876] Loss: 1.9649
Epoch [5/25], lter [65/876] Loss: 1.9947
Epoch [5/25], lter [70/876] Loss: 1.9381
Epoch [5/25], lter [75/876] Loss: 2.0089
Epoch [5/25], lter [80/876] Loss: 1.9867
Epoch [5/25], lter [85/876] Loss: 1.9815
Epoch [5/25], lter [90/876] Loss: 1.9739
Epoch [5/25], lter [95/876] Loss: 1.9791
Epoch [5/25], lter [100/876] Loss: 1.9961
Epoch [5/25], lter [105/876] Loss: 1.9579
Epoch [5/25], lter [110/876] Loss: 2.0142
Epoch [5/25], lter [115/876] Loss: 1.9319
Epoch [5/25], lter [120/876] Los

 20%|██        | 5/25 [4:42:22<18:39:08, 3357.43s/it]

Epoch 5/24
----------
Epoch [6/25], lter [5/876] Loss: 1.9124
Epoch [6/25], lter [10/876] Loss: 1.9572
Epoch [6/25], lter [15/876] Loss: 1.9105
Epoch [6/25], lter [20/876] Loss: 2.0444
Epoch [6/25], lter [25/876] Loss: 1.9668
Epoch [6/25], lter [30/876] Loss: 1.9819
Epoch [6/25], lter [35/876] Loss: 2.0127
Epoch [6/25], lter [40/876] Loss: 1.9158
Epoch [6/25], lter [45/876] Loss: 1.9707
Epoch [6/25], lter [50/876] Loss: 1.9567
Epoch [6/25], lter [55/876] Loss: 1.9211
Epoch [6/25], lter [60/876] Loss: 1.9762
Epoch [6/25], lter [65/876] Loss: 1.9215
Epoch [6/25], lter [70/876] Loss: 1.9764
Epoch [6/25], lter [75/876] Loss: 1.9652
Epoch [6/25], lter [80/876] Loss: 2.0106
Epoch [6/25], lter [85/876] Loss: 1.9405
Epoch [6/25], lter [90/876] Loss: 1.9954
Epoch [6/25], lter [95/876] Loss: 1.9549
Epoch [6/25], lter [100/876] Loss: 1.9699
Epoch [6/25], lter [105/876] Loss: 1.9849
Epoch [6/25], lter [110/876] Loss: 1.9723
Epoch [6/25], lter [115/876] Loss: 1.9737
Epoch [6/25], lter [120/876] Los

 24%|██▍       | 6/25 [5:42:32<18:10:25, 3443.44s/it]

Epoch 6/24
----------
Epoch [7/25], lter [5/876] Loss: 1.9431
Epoch [7/25], lter [10/876] Loss: 1.9520
Epoch [7/25], lter [15/876] Loss: 1.9617
Epoch [7/25], lter [20/876] Loss: 1.9376
Epoch [7/25], lter [25/876] Loss: 2.0136
Epoch [7/25], lter [30/876] Loss: 1.9503
Epoch [7/25], lter [35/876] Loss: 1.9167
Epoch [7/25], lter [40/876] Loss: 2.0026
Epoch [7/25], lter [45/876] Loss: 1.9389
Epoch [7/25], lter [50/876] Loss: 2.0301
Epoch [7/25], lter [55/876] Loss: 2.0148
Epoch [7/25], lter [60/876] Loss: 2.0016
Epoch [7/25], lter [65/876] Loss: 2.0184
Epoch [7/25], lter [70/876] Loss: 1.9668
Epoch [7/25], lter [75/876] Loss: 1.9787
Epoch [7/25], lter [80/876] Loss: 1.9933
Epoch [7/25], lter [85/876] Loss: 1.9294
Epoch [7/25], lter [90/876] Loss: 1.9817
Epoch [7/25], lter [95/876] Loss: 1.9989
Epoch [7/25], lter [100/876] Loss: 2.0097
Epoch [7/25], lter [105/876] Loss: 2.0262
Epoch [7/25], lter [110/876] Loss: 1.9801
Epoch [7/25], lter [115/876] Loss: 1.9639
Epoch [7/25], lter [120/876] Los

 28%|██▊       | 7/25 [6:30:41<16:18:40, 3262.24s/it]

Epoch 7/24
----------
Epoch [8/25], lter [5/876] Loss: 1.9720
Epoch [8/25], lter [10/876] Loss: 1.9581
Epoch [8/25], lter [15/876] Loss: 1.9776
Epoch [8/25], lter [20/876] Loss: 1.9644
Epoch [8/25], lter [25/876] Loss: 1.9419
Epoch [8/25], lter [30/876] Loss: 1.9303
Epoch [8/25], lter [35/876] Loss: 1.9669
Epoch [8/25], lter [40/876] Loss: 1.9666
Epoch [8/25], lter [45/876] Loss: 1.9897
Epoch [8/25], lter [50/876] Loss: 1.9913
Epoch [8/25], lter [55/876] Loss: 1.9864
Epoch [8/25], lter [60/876] Loss: 1.9973
Epoch [8/25], lter [65/876] Loss: 1.9615
Epoch [8/25], lter [70/876] Loss: 1.9875
Epoch [8/25], lter [75/876] Loss: 1.9767
Epoch [8/25], lter [80/876] Loss: 1.9617
Epoch [8/25], lter [85/876] Loss: 1.9737
Epoch [8/25], lter [90/876] Loss: 1.9935
Epoch [8/25], lter [95/876] Loss: 2.0125
Epoch [8/25], lter [100/876] Loss: 1.9855
Epoch [8/25], lter [105/876] Loss: 1.9696
Epoch [8/25], lter [110/876] Loss: 1.9388
Epoch [8/25], lter [115/876] Loss: 1.9138
Epoch [8/25], lter [120/876] Los

 32%|███▏      | 8/25 [7:18:50<14:50:34, 3143.20s/it]

Epoch 8/24
----------
Epoch [9/25], lter [5/876] Loss: 1.9389
Epoch [9/25], lter [10/876] Loss: 1.9226
Epoch [9/25], lter [15/876] Loss: 1.9780
Epoch [9/25], lter [20/876] Loss: 1.9868
Epoch [9/25], lter [25/876] Loss: 1.9452
Epoch [9/25], lter [30/876] Loss: 1.9712
Epoch [9/25], lter [35/876] Loss: 2.0091
Epoch [9/25], lter [40/876] Loss: 1.9551
Epoch [9/25], lter [45/876] Loss: 1.9571
Epoch [9/25], lter [50/876] Loss: 1.9435
Epoch [9/25], lter [55/876] Loss: 1.9526
Epoch [9/25], lter [60/876] Loss: 1.9611
Epoch [9/25], lter [65/876] Loss: 1.9831
Epoch [9/25], lter [70/876] Loss: 1.9704
Epoch [9/25], lter [75/876] Loss: 1.9350
Epoch [9/25], lter [80/876] Loss: 1.9784
Epoch [9/25], lter [85/876] Loss: 1.9320
Epoch [9/25], lter [90/876] Loss: 2.0397
Epoch [9/25], lter [95/876] Loss: 1.9489
Epoch [9/25], lter [100/876] Loss: 1.9307
Epoch [9/25], lter [105/876] Loss: 1.9776
Epoch [9/25], lter [110/876] Loss: 1.9688
Epoch [9/25], lter [115/876] Loss: 2.0097
Epoch [9/25], lter [120/876] Los

 36%|███▌      | 9/25 [8:07:12<13:38:06, 3067.89s/it]

Epoch 9/24
----------
Epoch [10/25], lter [5/876] Loss: 1.9885
Epoch [10/25], lter [10/876] Loss: 1.9258
Epoch [10/25], lter [15/876] Loss: 2.0037
Epoch [10/25], lter [20/876] Loss: 1.9991
Epoch [10/25], lter [25/876] Loss: 1.9393
Epoch [10/25], lter [30/876] Loss: 2.0289
Epoch [10/25], lter [35/876] Loss: 1.9934
Epoch [10/25], lter [40/876] Loss: 1.9640
Epoch [10/25], lter [45/876] Loss: 1.9546
Epoch [10/25], lter [50/876] Loss: 2.0114
Epoch [10/25], lter [55/876] Loss: 1.9027
Epoch [10/25], lter [60/876] Loss: 1.9664
Epoch [10/25], lter [65/876] Loss: 1.9677
Epoch [10/25], lter [70/876] Loss: 1.9563
Epoch [10/25], lter [75/876] Loss: 1.9571
Epoch [10/25], lter [80/876] Loss: 1.9676
Epoch [10/25], lter [85/876] Loss: 2.0213
Epoch [10/25], lter [90/876] Loss: 1.9657
Epoch [10/25], lter [95/876] Loss: 2.0125
Epoch [10/25], lter [100/876] Loss: 1.9510
Epoch [10/25], lter [105/876] Loss: 1.9467
Epoch [10/25], lter [110/876] Loss: 1.9282
Epoch [10/25], lter [115/876] Loss: 1.9440
Epoch [10

 40%|████      | 10/25 [8:55:35<12:34:16, 3017.07s/it]

Epoch 10/24
----------
Epoch [11/25], lter [5/876] Loss: 1.9532
Epoch [11/25], lter [10/876] Loss: 2.0286
Epoch [11/25], lter [15/876] Loss: 2.0169
Epoch [11/25], lter [20/876] Loss: 1.9726
Epoch [11/25], lter [25/876] Loss: 1.9469
Epoch [11/25], lter [30/876] Loss: 2.0243
Epoch [11/25], lter [35/876] Loss: 1.9605
Epoch [11/25], lter [40/876] Loss: 1.9762
Epoch [11/25], lter [45/876] Loss: 1.9544
Epoch [11/25], lter [50/876] Loss: 1.9808
Epoch [11/25], lter [55/876] Loss: 1.9588
Epoch [11/25], lter [60/876] Loss: 1.9361
Epoch [11/25], lter [65/876] Loss: 2.0106
Epoch [11/25], lter [70/876] Loss: 2.0042
Epoch [11/25], lter [75/876] Loss: 1.9833
Epoch [11/25], lter [80/876] Loss: 1.9957
Epoch [11/25], lter [85/876] Loss: 1.9807
Epoch [11/25], lter [90/876] Loss: 2.0026
Epoch [11/25], lter [95/876] Loss: 2.0153
Epoch [11/25], lter [100/876] Loss: 1.9909
Epoch [11/25], lter [105/876] Loss: 2.0084
Epoch [11/25], lter [110/876] Loss: 2.0072
Epoch [11/25], lter [115/876] Loss: 2.0215
Epoch [1

 44%|████▍     | 11/25 [9:43:51<11:35:20, 2980.02s/it]

Epoch 11/24
----------
Epoch [12/25], lter [5/876] Loss: 2.0340
Epoch [12/25], lter [10/876] Loss: 1.9927
Epoch [12/25], lter [15/876] Loss: 2.0001
Epoch [12/25], lter [20/876] Loss: 1.9797
Epoch [12/25], lter [25/876] Loss: 1.9642
Epoch [12/25], lter [30/876] Loss: 1.9328
Epoch [12/25], lter [35/876] Loss: 1.9590
Epoch [12/25], lter [40/876] Loss: 1.9954
Epoch [12/25], lter [45/876] Loss: 1.9564
Epoch [12/25], lter [50/876] Loss: 1.9496
Epoch [12/25], lter [55/876] Loss: 1.9919
Epoch [12/25], lter [60/876] Loss: 1.9448
Epoch [12/25], lter [65/876] Loss: 1.9669
Epoch [12/25], lter [70/876] Loss: 1.9561
Epoch [12/25], lter [75/876] Loss: 1.9873
Epoch [12/25], lter [80/876] Loss: 1.9698
Epoch [12/25], lter [85/876] Loss: 1.9672
Epoch [12/25], lter [90/876] Loss: 2.0373
Epoch [12/25], lter [95/876] Loss: 1.9251
Epoch [12/25], lter [100/876] Loss: 1.9820
Epoch [12/25], lter [105/876] Loss: 1.9476
Epoch [12/25], lter [110/876] Loss: 1.9593
Epoch [12/25], lter [115/876] Loss: 1.9778
Epoch [1

 48%|████▊     | 12/25 [10:32:21<10:41:02, 2958.68s/it]

Epoch 12/24
----------
Epoch [13/25], lter [5/876] Loss: 1.9339
Epoch [13/25], lter [10/876] Loss: 1.9810
Epoch [13/25], lter [15/876] Loss: 1.9437
Epoch [13/25], lter [20/876] Loss: 1.9656
Epoch [13/25], lter [25/876] Loss: 1.9391
Epoch [13/25], lter [30/876] Loss: 1.9543
Epoch [13/25], lter [35/876] Loss: 2.0034
Epoch [13/25], lter [40/876] Loss: 1.9774
Epoch [13/25], lter [45/876] Loss: 1.9637
Epoch [13/25], lter [50/876] Loss: 1.9705
Epoch [13/25], lter [55/876] Loss: 1.9549
Epoch [13/25], lter [60/876] Loss: 1.9823
Epoch [13/25], lter [65/876] Loss: 1.9403
Epoch [13/25], lter [70/876] Loss: 2.0165
Epoch [13/25], lter [75/876] Loss: 1.9886
Epoch [13/25], lter [80/876] Loss: 2.0259
Epoch [13/25], lter [85/876] Loss: 1.9891
Epoch [13/25], lter [90/876] Loss: 2.0081
Epoch [13/25], lter [95/876] Loss: 1.9447
Epoch [13/25], lter [100/876] Loss: 1.9643
Epoch [13/25], lter [105/876] Loss: 2.0007
Epoch [13/25], lter [110/876] Loss: 1.9853
Epoch [13/25], lter [115/876] Loss: 1.9404
Epoch [1

 52%|█████▏    | 13/25 [11:20:44<9:48:22, 2941.87s/it] 

Epoch 13/24
----------
Epoch [14/25], lter [5/876] Loss: 2.0172
Epoch [14/25], lter [10/876] Loss: 1.9749
Epoch [14/25], lter [15/876] Loss: 1.9658
Epoch [14/25], lter [20/876] Loss: 1.9728
Epoch [14/25], lter [25/876] Loss: 1.9565
Epoch [14/25], lter [30/876] Loss: 1.9815
Epoch [14/25], lter [35/876] Loss: 1.9750
Epoch [14/25], lter [40/876] Loss: 1.9769
Epoch [14/25], lter [45/876] Loss: 2.0041
Epoch [14/25], lter [50/876] Loss: 1.9248
Epoch [14/25], lter [55/876] Loss: 1.9956
Epoch [14/25], lter [60/876] Loss: 1.9890
Epoch [14/25], lter [65/876] Loss: 2.0055
Epoch [14/25], lter [70/876] Loss: 1.9752
Epoch [14/25], lter [75/876] Loss: 1.9793
Epoch [14/25], lter [80/876] Loss: 1.9725
Epoch [14/25], lter [85/876] Loss: 1.9903
Epoch [14/25], lter [90/876] Loss: 1.9310
Epoch [14/25], lter [95/876] Loss: 2.0310
Epoch [14/25], lter [100/876] Loss: 1.9038
Epoch [14/25], lter [105/876] Loss: 1.9904
Epoch [14/25], lter [110/876] Loss: 1.9494
Epoch [14/25], lter [115/876] Loss: 1.9958
Epoch [1

 56%|█████▌    | 14/25 [12:09:03<8:56:58, 2928.93s/it]

Epoch 14/24
----------
Epoch [15/25], lter [5/876] Loss: 1.9977
Epoch [15/25], lter [10/876] Loss: 1.9994
Epoch [15/25], lter [15/876] Loss: 1.9518
Epoch [15/25], lter [20/876] Loss: 1.9698
Epoch [15/25], lter [25/876] Loss: 1.9543
Epoch [15/25], lter [30/876] Loss: 1.9512
Epoch [15/25], lter [35/876] Loss: 1.9374
Epoch [15/25], lter [40/876] Loss: 1.9711
Epoch [15/25], lter [45/876] Loss: 1.9298
Epoch [15/25], lter [50/876] Loss: 1.9902
Epoch [15/25], lter [55/876] Loss: 1.9800
Epoch [15/25], lter [60/876] Loss: 1.9937
Epoch [15/25], lter [65/876] Loss: 1.9934
Epoch [15/25], lter [70/876] Loss: 1.9844
Epoch [15/25], lter [75/876] Loss: 1.9806
Epoch [15/25], lter [80/876] Loss: 1.9602
Epoch [15/25], lter [85/876] Loss: 1.9670
Epoch [15/25], lter [90/876] Loss: 1.9910
Epoch [15/25], lter [95/876] Loss: 1.9791
Epoch [15/25], lter [100/876] Loss: 1.9837
Epoch [15/25], lter [105/876] Loss: 1.9197
Epoch [15/25], lter [110/876] Loss: 1.9277
Epoch [15/25], lter [115/876] Loss: 1.9734
Epoch [1

 60%|██████    | 15/25 [12:57:33<8:07:10, 2923.05s/it]

Epoch 15/24
----------
Epoch [16/25], lter [5/876] Loss: 1.9169
Epoch [16/25], lter [10/876] Loss: 2.0168
Epoch [16/25], lter [15/876] Loss: 1.9248
Epoch [16/25], lter [20/876] Loss: 1.9492
Epoch [16/25], lter [25/876] Loss: 1.9434
Epoch [16/25], lter [30/876] Loss: 1.9586
Epoch [16/25], lter [35/876] Loss: 1.9912
Epoch [16/25], lter [40/876] Loss: 2.0130
Epoch [16/25], lter [45/876] Loss: 1.9533
Epoch [16/25], lter [50/876] Loss: 1.9249
Epoch [16/25], lter [55/876] Loss: 1.9021
Epoch [16/25], lter [60/876] Loss: 2.0265
Epoch [16/25], lter [65/876] Loss: 1.9415
Epoch [16/25], lter [70/876] Loss: 1.9663
Epoch [16/25], lter [75/876] Loss: 1.9861
Epoch [16/25], lter [80/876] Loss: 1.9951
Epoch [16/25], lter [85/876] Loss: 1.9578
Epoch [16/25], lter [90/876] Loss: 1.9922
Epoch [16/25], lter [95/876] Loss: 1.9062
Epoch [16/25], lter [100/876] Loss: 1.9806
Epoch [16/25], lter [105/876] Loss: 1.9854
Epoch [16/25], lter [110/876] Loss: 1.9869
Epoch [16/25], lter [115/876] Loss: 1.9762
Epoch [1

 64%|██████▍   | 16/25 [13:45:55<7:17:31, 2916.82s/it]

Epoch 16/24
----------
Epoch [17/25], lter [5/876] Loss: 1.9783
Epoch [17/25], lter [10/876] Loss: 2.0279
Epoch [17/25], lter [15/876] Loss: 1.9956
Epoch [17/25], lter [20/876] Loss: 1.9855
Epoch [17/25], lter [25/876] Loss: 1.9555
Epoch [17/25], lter [30/876] Loss: 1.9782
Epoch [17/25], lter [35/876] Loss: 1.9364
Epoch [17/25], lter [40/876] Loss: 1.9441
Epoch [17/25], lter [45/876] Loss: 2.0195
Epoch [17/25], lter [50/876] Loss: 1.9798
Epoch [17/25], lter [55/876] Loss: 1.9741
Epoch [17/25], lter [60/876] Loss: 1.9505
Epoch [17/25], lter [65/876] Loss: 2.0184
Epoch [17/25], lter [70/876] Loss: 1.9694
Epoch [17/25], lter [75/876] Loss: 1.9582
Epoch [17/25], lter [80/876] Loss: 1.9988
Epoch [17/25], lter [85/876] Loss: 1.9482
Epoch [17/25], lter [90/876] Loss: 1.9133
Epoch [17/25], lter [95/876] Loss: 1.9930
Epoch [17/25], lter [100/876] Loss: 1.9677
Epoch [17/25], lter [105/876] Loss: 1.9196
Epoch [17/25], lter [110/876] Loss: 1.9560
Epoch [17/25], lter [115/876] Loss: 2.0130
Epoch [1

 68%|██████▊   | 17/25 [14:34:22<6:28:31, 2913.95s/it]

Epoch 17/24
----------
Epoch [18/25], lter [5/876] Loss: 1.9763
Epoch [18/25], lter [10/876] Loss: 1.9799
Epoch [18/25], lter [15/876] Loss: 1.9856
Epoch [18/25], lter [20/876] Loss: 1.9748
Epoch [18/25], lter [25/876] Loss: 1.9425
Epoch [18/25], lter [30/876] Loss: 1.9351
Epoch [18/25], lter [35/876] Loss: 2.0032
Epoch [18/25], lter [40/876] Loss: 1.9917
Epoch [18/25], lter [45/876] Loss: 1.9605
Epoch [18/25], lter [50/876] Loss: 1.9526
Epoch [18/25], lter [55/876] Loss: 1.9054
Epoch [18/25], lter [60/876] Loss: 2.0112
Epoch [18/25], lter [65/876] Loss: 2.0050
Epoch [18/25], lter [70/876] Loss: 1.9585
Epoch [18/25], lter [75/876] Loss: 1.9989
Epoch [18/25], lter [80/876] Loss: 2.0708
Epoch [18/25], lter [85/876] Loss: 1.9605
Epoch [18/25], lter [90/876] Loss: 2.0073
Epoch [18/25], lter [95/876] Loss: 1.9655
Epoch [18/25], lter [100/876] Loss: 1.9886
Epoch [18/25], lter [105/876] Loss: 1.9894
Epoch [18/25], lter [110/876] Loss: 1.9432
Epoch [18/25], lter [115/876] Loss: 1.9537
Epoch [1

 72%|███████▏  | 18/25 [15:22:54<5:39:53, 2913.33s/it]

Epoch 18/24
----------
Epoch [19/25], lter [5/876] Loss: 1.8989
Epoch [19/25], lter [10/876] Loss: 2.0030
Epoch [19/25], lter [15/876] Loss: 1.9613
Epoch [19/25], lter [20/876] Loss: 1.9796
Epoch [19/25], lter [25/876] Loss: 2.0101
Epoch [19/25], lter [30/876] Loss: 1.9755
Epoch [19/25], lter [35/876] Loss: 1.9819
Epoch [19/25], lter [40/876] Loss: 2.0241
Epoch [19/25], lter [45/876] Loss: 1.9478
Epoch [19/25], lter [50/876] Loss: 1.9816
Epoch [19/25], lter [55/876] Loss: 2.0362
Epoch [19/25], lter [60/876] Loss: 1.9724
Epoch [19/25], lter [65/876] Loss: 2.0158
Epoch [19/25], lter [70/876] Loss: 2.0151
Epoch [19/25], lter [75/876] Loss: 1.9894
Epoch [19/25], lter [80/876] Loss: 1.9848
Epoch [19/25], lter [85/876] Loss: 2.0167
Epoch [19/25], lter [90/876] Loss: 1.9340
Epoch [19/25], lter [95/876] Loss: 1.9725
Epoch [19/25], lter [100/876] Loss: 1.9738
Epoch [19/25], lter [105/876] Loss: 2.0207
Epoch [19/25], lter [110/876] Loss: 1.9634
Epoch [19/25], lter [115/876] Loss: 1.9811
Epoch [1

 76%|███████▌  | 19/25 [16:13:17<4:54:36, 2946.15s/it]

Epoch 19/24
----------
Epoch [20/25], lter [5/876] Loss: 1.9947
Epoch [20/25], lter [10/876] Loss: 1.9464
Epoch [20/25], lter [15/876] Loss: 1.9448
Epoch [20/25], lter [20/876] Loss: 1.9650
Epoch [20/25], lter [25/876] Loss: 1.9814
Epoch [20/25], lter [30/876] Loss: 1.9830
Epoch [20/25], lter [35/876] Loss: 1.9986
Epoch [20/25], lter [40/876] Loss: 2.0379
Epoch [20/25], lter [45/876] Loss: 2.0515
Epoch [20/25], lter [50/876] Loss: 1.9272
Epoch [20/25], lter [55/876] Loss: 2.0049
Epoch [20/25], lter [60/876] Loss: 1.9748
Epoch [20/25], lter [65/876] Loss: 1.9156
Epoch [20/25], lter [70/876] Loss: 1.9687
Epoch [20/25], lter [75/876] Loss: 1.9708
Epoch [20/25], lter [80/876] Loss: 1.9527
Epoch [20/25], lter [85/876] Loss: 1.9969
Epoch [20/25], lter [90/876] Loss: 1.9590
Epoch [20/25], lter [95/876] Loss: 1.9990
Epoch [20/25], lter [100/876] Loss: 1.9374
Epoch [20/25], lter [105/876] Loss: 1.9409
Epoch [20/25], lter [110/876] Loss: 1.9374
Epoch [20/25], lter [115/876] Loss: 2.0278
Epoch [2

 80%|████████  | 20/25 [17:03:24<4:07:01, 2964.32s/it]

Epoch 20/24
----------
Epoch [21/25], lter [5/876] Loss: 1.9561
Epoch [21/25], lter [10/876] Loss: 1.9601
Epoch [21/25], lter [15/876] Loss: 2.0122
Epoch [21/25], lter [20/876] Loss: 1.9367
Epoch [21/25], lter [25/876] Loss: 1.9194
Epoch [21/25], lter [30/876] Loss: 1.9429
Epoch [21/25], lter [35/876] Loss: 1.9762
Epoch [21/25], lter [40/876] Loss: 1.9676
Epoch [21/25], lter [45/876] Loss: 2.0145
Epoch [21/25], lter [50/876] Loss: 2.0281
Epoch [21/25], lter [55/876] Loss: 1.9791
Epoch [21/25], lter [60/876] Loss: 1.9669
Epoch [21/25], lter [65/876] Loss: 2.0044
Epoch [21/25], lter [70/876] Loss: 1.9867
Epoch [21/25], lter [75/876] Loss: 1.9968
Epoch [21/25], lter [80/876] Loss: 1.9711
Epoch [21/25], lter [85/876] Loss: 2.0362
Epoch [21/25], lter [90/876] Loss: 1.9847
Epoch [21/25], lter [95/876] Loss: 2.0022
Epoch [21/25], lter [100/876] Loss: 1.9807
Epoch [21/25], lter [105/876] Loss: 1.9040
Epoch [21/25], lter [110/876] Loss: 1.9420
Epoch [21/25], lter [115/876] Loss: 2.0210
Epoch [2

 84%|████████▍ | 21/25 [18:05:17<3:32:36, 3189.22s/it]

Epoch 21/24
----------
Epoch [22/25], lter [5/876] Loss: 2.0175
Epoch [22/25], lter [10/876] Loss: 1.9670
Epoch [22/25], lter [15/876] Loss: 1.9916
Epoch [22/25], lter [20/876] Loss: 1.9894
Epoch [22/25], lter [25/876] Loss: 1.9805
Epoch [22/25], lter [30/876] Loss: 1.9850
Epoch [22/25], lter [35/876] Loss: 1.9573
Epoch [22/25], lter [40/876] Loss: 1.9076
Epoch [22/25], lter [45/876] Loss: 2.0162
Epoch [22/25], lter [50/876] Loss: 1.9932
Epoch [22/25], lter [55/876] Loss: 1.9971
Epoch [22/25], lter [60/876] Loss: 1.9562
Epoch [22/25], lter [65/876] Loss: 1.9755
Epoch [22/25], lter [70/876] Loss: 1.9660
Epoch [22/25], lter [75/876] Loss: 1.9582
Epoch [22/25], lter [80/876] Loss: 1.9744
Epoch [22/25], lter [85/876] Loss: 2.0110
Epoch [22/25], lter [90/876] Loss: 1.9760
Epoch [22/25], lter [95/876] Loss: 2.0060
Epoch [22/25], lter [100/876] Loss: 1.9912
Epoch [22/25], lter [105/876] Loss: 2.0157
Epoch [22/25], lter [110/876] Loss: 1.9577
Epoch [22/25], lter [115/876] Loss: 1.9631
Epoch [2

 88%|████████▊ | 22/25 [19:02:32<2:43:08, 3262.87s/it]

Epoch 22/24
----------
Epoch [23/25], lter [5/876] Loss: 1.9871
Epoch [23/25], lter [10/876] Loss: 2.0234
Epoch [23/25], lter [15/876] Loss: 2.0165
Epoch [23/25], lter [20/876] Loss: 1.9630
Epoch [23/25], lter [25/876] Loss: 1.9844
Epoch [23/25], lter [30/876] Loss: 2.0285
Epoch [23/25], lter [35/876] Loss: 1.9752
Epoch [23/25], lter [40/876] Loss: 1.9699
Epoch [23/25], lter [45/876] Loss: 2.0231
Epoch [23/25], lter [50/876] Loss: 1.9487
Epoch [23/25], lter [55/876] Loss: 1.9546
Epoch [23/25], lter [60/876] Loss: 1.9557
Epoch [23/25], lter [65/876] Loss: 2.0036
Epoch [23/25], lter [70/876] Loss: 1.9531
Epoch [23/25], lter [75/876] Loss: 2.0051
Epoch [23/25], lter [80/876] Loss: 1.9466
Epoch [23/25], lter [85/876] Loss: 1.9753
Epoch [23/25], lter [90/876] Loss: 2.0108
Epoch [23/25], lter [95/876] Loss: 2.0005
Epoch [23/25], lter [100/876] Loss: 1.9936
Epoch [23/25], lter [105/876] Loss: 1.9674
Epoch [23/25], lter [110/876] Loss: 1.9314
Epoch [23/25], lter [115/876] Loss: 1.9698
Epoch [2

 92%|█████████▏| 23/25 [20:00:17<1:50:47, 3323.63s/it]

Epoch 23/24
----------
Epoch [24/25], lter [5/876] Loss: 1.9973
Epoch [24/25], lter [10/876] Loss: 2.0050
Epoch [24/25], lter [15/876] Loss: 1.9513
Epoch [24/25], lter [20/876] Loss: 1.9088
Epoch [24/25], lter [25/876] Loss: 1.9444
Epoch [24/25], lter [30/876] Loss: 1.9716
Epoch [24/25], lter [35/876] Loss: 2.0158
Epoch [24/25], lter [40/876] Loss: 1.9810
Epoch [24/25], lter [45/876] Loss: 1.9577
Epoch [24/25], lter [50/876] Loss: 2.0360
Epoch [24/25], lter [55/876] Loss: 1.9849
Epoch [24/25], lter [60/876] Loss: 2.0012
Epoch [24/25], lter [65/876] Loss: 1.9792
Epoch [24/25], lter [70/876] Loss: 1.9813
Epoch [24/25], lter [75/876] Loss: 2.0058
Epoch [24/25], lter [80/876] Loss: 1.9815
Epoch [24/25], lter [85/876] Loss: 2.0121
Epoch [24/25], lter [90/876] Loss: 1.9473
Epoch [24/25], lter [95/876] Loss: 2.0075
Epoch [24/25], lter [100/876] Loss: 1.9752
Epoch [24/25], lter [105/876] Loss: 2.0039
Epoch [24/25], lter [110/876] Loss: 1.9565
Epoch [24/25], lter [115/876] Loss: 1.9650
Epoch [2

 96%|█████████▌| 24/25 [21:01:56<57:16, 3436.27s/it]  

Epoch 24/24
----------
Epoch [25/25], lter [5/876] Loss: 2.0095
Epoch [25/25], lter [10/876] Loss: 1.9484
Epoch [25/25], lter [15/876] Loss: 1.9448
Epoch [25/25], lter [20/876] Loss: 1.9782
Epoch [25/25], lter [25/876] Loss: 1.9609
Epoch [25/25], lter [30/876] Loss: 1.9796
Epoch [25/25], lter [35/876] Loss: 1.9139
Epoch [25/25], lter [40/876] Loss: 1.9884
Epoch [25/25], lter [45/876] Loss: 1.9350
Epoch [25/25], lter [50/876] Loss: 1.9400
Epoch [25/25], lter [55/876] Loss: 1.9600
Epoch [25/25], lter [60/876] Loss: 2.0028
Epoch [25/25], lter [65/876] Loss: 1.9646
Epoch [25/25], lter [70/876] Loss: 1.9504
Epoch [25/25], lter [75/876] Loss: 2.0090
Epoch [25/25], lter [80/876] Loss: 2.0156
Epoch [25/25], lter [85/876] Loss: 1.9581
Epoch [25/25], lter [90/876] Loss: 1.9582
Epoch [25/25], lter [95/876] Loss: 1.9992
Epoch [25/25], lter [100/876] Loss: 1.9665
Epoch [25/25], lter [105/876] Loss: 1.9443
Epoch [25/25], lter [110/876] Loss: 2.0233
Epoch [25/25], lter [115/876] Loss: 1.9794
Epoch [2

100%|██████████| 25/25 [21:55:30<00:00, 3157.20s/it]


Accuracy of test images: 15.480006 %


In [48]:
torch.save(model, 'inception.pt')

# GARBAGE

In [None]:
# loop over epochs
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(5)):
	# set the model in training mode
	model.train()
	# initialize the total training and validation loss
	totalTrainLoss = 0
	totalValLoss = 0
	# initialize the number of correct predictions in the training
	# and validation step
	trainCorrect = 0
	valCorrect = 0
	# loop over the training set
	for (i, (x, y)) in enumerate(train_data_loader):
		# send the input to the device
		#(x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
		# perform a forward pass and calculate the training loss
		pred = model(x)
		loss = lossFunc(pred, y)
		# calculate the gradients
		loss.backward()
		# check if we are updating the model parameters and if so
		# update them, and zero out the previously accumulated gradients
		if (i + 2) % 2 == 0:
			opt.step()
			opt.zero_grad()
		# add the loss to the total training loss so far and
		# calculate the number of correct predictions
		totalTrainLoss += loss
		trainCorrect += (pred.argmax(1) == y).type(
			torch.float).sum().item()

    	# switch off autograd
	with torch.no_grad():
		# set the model in evaluation mode
		model.eval()
		# loop over the validation set
		for (x, y) in test_data_loader:
			# send the input to the device
			#(x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
			# make the predictions and calculate the validation loss
			pred = model(x)
			totalValLoss += lossFunc(pred, y)
			# calculate the number of correct predictions
			valCorrect += (pred.argmax(1) == y).type(
				torch.float).sum().item()

    	# calculate the average training and validation loss
	avgTrainLoss = totalTrainLoss / trainSteps
	avgValLoss = totalValLoss / valSteps
	# calculate the training and validation accuracy
	trainCorrect = trainCorrect / len(trainDS)
	valCorrect = valCorrect / len(valDS)
	# update our training history
	H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
	H["train_acc"].append(trainCorrect)
	H["val_loss"].append(avgValLoss.cpu().detach().numpy())
	H["val_acc"].append(valCorrect)
	# print the model training and validation information
	print("[INFO] EPOCH: {}/{}".format(e + 1, 5))
	print("Train loss: {:.6f}, Train accuracy: {:.4f}".format(
		avgTrainLoss, trainCorrect))
	print("Val loss: {:.6f}, Val accuracy: {:.4f}".format(
		avgValLoss, valCorrect))

# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
	endTime - startTime))
# plot the training loss and accuracy
# plt.style.use("ggplot")
# plt.figure()
# plt.plot(H["train_loss"], label="train_loss")
# plt.plot(H["val_loss"], label="val_loss")
# plt.plot(H["train_acc"], label="train_acc")
# plt.plot(H["val_acc"], label="val_acc")
# plt.title("Training Loss and Accuracy on Dataset")
# plt.xlabel("Epoch #")
# plt.ylabel("Loss/Accuracy")
# plt.legend(loc="lower left")
# plt.savefig(config.WARMUP_PLOT)
# serialize the model to disk
torch.save(model, config.WARMUP_MODEL)

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
            running_loss = 0.0
            running_corrects = 0
            # Iterate over data.
            for inputs, labels in train_data_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()
            epoch_loss = running_loss / dataset_sizes
            epoch_acc = running_corrects.double() / dataset_sizes
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            # deep copy the model
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [28]:
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)

Epoch 0/24
----------


TypeError: max() received an invalid combination of arguments - got (InceptionOutputs, int), but expected one of:
 * (Tensor input, *, Tensor out)
 * (Tensor input, Tensor other, *, Tensor out)
 * (Tensor input, int dim, bool keepdim, *, tuple of Tensors out)
 * (Tensor input, name dim, bool keepdim, *, tuple of Tensors out)
