In [1]:
#!pip install datasets
#!pip install --upgrade --force-reinstall huggingface_hub

In [2]:
import torch
import torchvision
from torchvision.transforms import transforms, Lambda, Resize
from torchvision import datasets, transforms, models
import os
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Specify where to find the data preparation class
import sys
sys.path.append('../../Data_Preparation')
from Preparation import CustomDataLoader

In [3]:
# ResNet50 training data (ImageNet) properties
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
#DIMENSIONS = 3
#SIZE = 256

LR = 0.1
MOMENTUM=0.85
WEIGHT_DECAY = 1e-4
LR_STEP_SIZE = 5
LR_GAMMA = 0.1

In [4]:
dataloaders = {x: CustomDataLoader(data_path="../../FER2013_Data", batch_size=16, dataset_type=x, mean=MEAN, std=STD, dimensions=3).data_loader for x in ['train', 'test']}
dataset_sizes = {x: len(dataloaders[x]) for x in ['train', 'test']}
print(dataset_sizes)

class_names = dataloaders['train'].dataset.classes
print(class_names)

# Confirm correct data load
print("Train Data Loader:")
for batch_idx, (inputs, labels) in enumerate(dataloaders['train']):
    print("Batch Index:", batch_idx)
    print("Inputs Shape:", inputs.shape)
    print("Labels Shape:", labels.shape)
    # Print the first few labels in the batch
    print("Labels:", labels[:5])
    # Break after printing a few batches
    if batch_idx == 2:
        break

{'train': 1753, 'test': 449}
['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
Train Data Loader:
Batch Index: 0
Inputs Shape: torch.Size([16, 3, 299, 299])
Labels Shape: torch.Size([16])
Labels: tensor([6, 4, 3, 0, 0])
Batch Index: 1
Inputs Shape: torch.Size([16, 3, 299, 299])
Labels Shape: torch.Size([16])
Labels: tensor([3, 3, 3, 4, 5])
Batch Index: 2
Inputs Shape: torch.Size([16, 3, 299, 299])
Labels Shape: torch.Size([16])
Labels: tensor([5, 0, 6, 6, 0])


In [5]:
dataloaders['test']

<torch.utils.data.dataloader.DataLoader at 0x1bde0128e20>

In [6]:
# Load the pre-trained ResNet-50 model
model = models.resnet50(pretrained=True)

# Freeze all layers except the final classification layer
for name, param in model.named_parameters():
    if "fc" in name:  # Unfreeze the final classification layer
        param.requires_grad = True
    else:
        param.requires_grad = False

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY
)

LR_SCHEDULER = torch.optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)


# Move the model to the GPU if available
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model = model.to(device)



In [7]:
# num_epochs = 20
# for epoch in tqdm(range(num_epochs)):
#     for phase in ['train', 'test']:
#         if phase == 'train':
#             model.train()
#         else:
#             model.eval()

#         running_loss = 0.0
#         running_corrects = 0

#         for inputs, labels in dataloaders[phase]:
#             inputs = inputs.to(device)
#             labels = labels.to(device)

#             if phase == 'train':

#                 optimizer.zero_grad()

#                 with torch.set_grad_enabled(phase == 'train'):
#                     outputs = model(inputs)
#                     _, preds = torch.max(outputs, 1)
#                     loss = criterion(outputs, labels)

#                     if phase == 'train':
#                         loss.backward()
#                         optimizer.step()

#                 running_loss += loss.item() * inputs.size(0)
#                 running_corrects += torch.sum(preds == labels.data)

#             LR_SCHEDULER.step()

#             elif phase == 'test':
#                 with torch.inference_mode():  
#                     running_loss = 0
#                     for step, (images, labels) in enumerate(dataloaders[phase]):
#                         images, labels = images.to(device), labels.to(device)
                        
#                         outputs = model(images)
#                         loss = criterion(outputs, labels)

#                         running_loss += loss.item()

#     epoch_loss = running_loss / dataset_sizes[phase]
#     epoch_acc = running_corrects.double() / dataset_sizes[phase]

#     print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
# print("Training complete!")

In [8]:
num_epochs = 25

print("[INFO] Start training")
print("---------------------")

for epoch in tqdm(range(num_epochs)):

    total_batch = len(dataloaders['train'].dataset)//16
    
    # Training phase
    model.train()  # Set the model to train mode

    # Get statistics
    epoch_loss = 0
    len_dataset = 0
    
    for step, (batch_images, batch_labels) in enumerate(dataloaders['train']):
        X, Y = batch_images.to(device), batch_labels.to(device)

        pred = model(X)
        loss = criterion(pred, Y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += pred.shape[0] * loss.item()
        len_dataset += pred.shape[0]
        if (step) % 10 == 0:
            print('Epoch [%d/%d], lter [%d/%d] Loss: %.4f'
                 %(epoch+1, num_epochs, step+1, total_batch, loss.item()))

    epoch_loss = epoch_loss/ len_dataset
    print('Epoch: ', epoch+1, '| train loss : %0.4f' % epoch_loss)

    LR_SCHEDULER.step()
    
    # Validation phase
    model.eval()  # Set the model to evaluation mode
    with torch.inference_mode():  
        running_loss = 0
        for step, (images, labels) in enumerate(dataloaders['test']):
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            
    running_loss = running_loss / len(dataloaders['test'])
    print('Epoch: ', epoch, '| test loss : %0.4f' % running_loss )

[INFO] Start training
---------------------


  0%|          | 0/25 [00:00<?, ?it/s]

Epoch [1/25], lter [1/1752] Loss: 8.3915
Epoch [1/25], lter [11/1752] Loss: 43.5536
Epoch [1/25], lter [21/1752] Loss: 48.0594
Epoch [1/25], lter [31/1752] Loss: 30.8403
Epoch [1/25], lter [41/1752] Loss: 11.7199
Epoch [1/25], lter [51/1752] Loss: 55.1133
Epoch [1/25], lter [61/1752] Loss: 30.1797
Epoch [1/25], lter [71/1752] Loss: 43.7683
Epoch [1/25], lter [81/1752] Loss: 37.7293
Epoch [1/25], lter [91/1752] Loss: 11.4472
Epoch [1/25], lter [101/1752] Loss: 23.5558
Epoch [1/25], lter [111/1752] Loss: 17.9113
Epoch [1/25], lter [121/1752] Loss: 27.5843
Epoch [1/25], lter [131/1752] Loss: 16.4878
Epoch [1/25], lter [141/1752] Loss: 11.9479
Epoch [1/25], lter [151/1752] Loss: 31.8886
Epoch [1/25], lter [161/1752] Loss: 26.3392
Epoch [1/25], lter [171/1752] Loss: 26.2997
Epoch [1/25], lter [181/1752] Loss: 21.3404
Epoch [1/25], lter [191/1752] Loss: 20.3694
Epoch [1/25], lter [201/1752] Loss: 20.1641
Epoch [1/25], lter [211/1752] Loss: 10.2102
Epoch [1/25], lter [221/1752] Loss: 19.1473


  4%|▍         | 1/25 [27:46<11:06:29, 1666.24s/it]

Epoch:  0 | test loss : 13.5076
Epoch [2/25], lter [1/1752] Loss: 9.6151
Epoch [2/25], lter [11/1752] Loss: 17.8819
Epoch [2/25], lter [21/1752] Loss: 36.2848
Epoch [2/25], lter [31/1752] Loss: 17.3159
Epoch [2/25], lter [41/1752] Loss: 39.4571
Epoch [2/25], lter [51/1752] Loss: 5.9503
Epoch [2/25], lter [61/1752] Loss: 21.7455
Epoch [2/25], lter [71/1752] Loss: 16.6650
Epoch [2/25], lter [81/1752] Loss: 12.3730
Epoch [2/25], lter [91/1752] Loss: 21.7604
Epoch [2/25], lter [101/1752] Loss: 21.8858
Epoch [2/25], lter [111/1752] Loss: 23.3653
Epoch [2/25], lter [121/1752] Loss: 19.2023
Epoch [2/25], lter [131/1752] Loss: 17.1474
Epoch [2/25], lter [141/1752] Loss: 17.0753
Epoch [2/25], lter [151/1752] Loss: 9.1430
Epoch [2/25], lter [161/1752] Loss: 9.0247
Epoch [2/25], lter [171/1752] Loss: 15.6967
Epoch [2/25], lter [181/1752] Loss: 16.7071
Epoch [2/25], lter [191/1752] Loss: 21.3785
Epoch [2/25], lter [201/1752] Loss: 9.9391
Epoch [2/25], lter [211/1752] Loss: 24.6614
Epoch [2/25], lt

  8%|▊         | 2/25 [55:39<10:40:14, 1670.21s/it]

Epoch:  1 | test loss : 11.3523
Epoch [3/25], lter [1/1752] Loss: 11.3415
Epoch [3/25], lter [11/1752] Loss: 11.7496
Epoch [3/25], lter [21/1752] Loss: 14.5131
Epoch [3/25], lter [31/1752] Loss: 9.4274
Epoch [3/25], lter [41/1752] Loss: 22.8844
Epoch [3/25], lter [51/1752] Loss: 12.0182
Epoch [3/25], lter [61/1752] Loss: 8.4712
Epoch [3/25], lter [71/1752] Loss: 11.0060
Epoch [3/25], lter [81/1752] Loss: 5.8364
Epoch [3/25], lter [91/1752] Loss: 10.0145
Epoch [3/25], lter [101/1752] Loss: 15.9685
Epoch [3/25], lter [111/1752] Loss: 16.5522
Epoch [3/25], lter [121/1752] Loss: 25.9802
Epoch [3/25], lter [131/1752] Loss: 20.1562
Epoch [3/25], lter [141/1752] Loss: 21.9921
Epoch [3/25], lter [151/1752] Loss: 23.3843
Epoch [3/25], lter [161/1752] Loss: 5.4981
Epoch [3/25], lter [171/1752] Loss: 5.2004
Epoch [3/25], lter [181/1752] Loss: 13.1406
Epoch [3/25], lter [191/1752] Loss: 12.9026
Epoch [3/25], lter [201/1752] Loss: 14.4857
Epoch [3/25], lter [211/1752] Loss: 19.5203
Epoch [3/25], lt

 12%|█▏        | 3/25 [1:23:22<10:11:12, 1666.93s/it]

Epoch:  2 | test loss : 12.7556
Epoch [4/25], lter [1/1752] Loss: 15.7068
Epoch [4/25], lter [11/1752] Loss: 26.2464
Epoch [4/25], lter [21/1752] Loss: 21.8334
Epoch [4/25], lter [31/1752] Loss: 9.6695
Epoch [4/25], lter [41/1752] Loss: 7.6187
Epoch [4/25], lter [51/1752] Loss: 13.9356
Epoch [4/25], lter [61/1752] Loss: 16.1994
Epoch [4/25], lter [71/1752] Loss: 13.7115
Epoch [4/25], lter [81/1752] Loss: 15.7973
Epoch [4/25], lter [91/1752] Loss: 11.3014
Epoch [4/25], lter [101/1752] Loss: 18.5332
Epoch [4/25], lter [111/1752] Loss: 23.3764
Epoch [4/25], lter [121/1752] Loss: 8.7927
Epoch [4/25], lter [131/1752] Loss: 13.4059
Epoch [4/25], lter [141/1752] Loss: 7.8181
Epoch [4/25], lter [151/1752] Loss: 27.2145
Epoch [4/25], lter [161/1752] Loss: 11.3879
Epoch [4/25], lter [171/1752] Loss: 15.1230
Epoch [4/25], lter [181/1752] Loss: 16.6947
Epoch [4/25], lter [191/1752] Loss: 13.2657
Epoch [4/25], lter [201/1752] Loss: 15.3388
Epoch [4/25], lter [211/1752] Loss: 9.8323
Epoch [4/25], lt

 16%|█▌        | 4/25 [1:51:08<9:43:18, 1666.57s/it] 

Epoch:  3 | test loss : 23.4379
Epoch [5/25], lter [1/1752] Loss: 31.6587
Epoch [5/25], lter [11/1752] Loss: 24.8556
Epoch [5/25], lter [21/1752] Loss: 37.6497
Epoch [5/25], lter [31/1752] Loss: 12.4848
Epoch [5/25], lter [41/1752] Loss: 12.2034
Epoch [5/25], lter [51/1752] Loss: 23.2942
Epoch [5/25], lter [61/1752] Loss: 17.9114
Epoch [5/25], lter [71/1752] Loss: 10.0938
Epoch [5/25], lter [81/1752] Loss: 18.5510
Epoch [5/25], lter [91/1752] Loss: 15.9989
Epoch [5/25], lter [101/1752] Loss: 21.0007
Epoch [5/25], lter [111/1752] Loss: 26.4566
Epoch [5/25], lter [121/1752] Loss: 10.6400
Epoch [5/25], lter [131/1752] Loss: 11.7163
Epoch [5/25], lter [141/1752] Loss: 20.3103
Epoch [5/25], lter [151/1752] Loss: 26.3817
Epoch [5/25], lter [161/1752] Loss: 22.2356
Epoch [5/25], lter [171/1752] Loss: 6.7204
Epoch [5/25], lter [181/1752] Loss: 9.3305
Epoch [5/25], lter [191/1752] Loss: 22.9725
Epoch [5/25], lter [201/1752] Loss: 14.0133
Epoch [5/25], lter [211/1752] Loss: 45.3598
Epoch [5/25],

 20%|██        | 5/25 [2:18:40<9:13:47, 1661.39s/it]

Epoch:  4 | test loss : 11.2142
Epoch [6/25], lter [1/1752] Loss: 15.9329
Epoch [6/25], lter [11/1752] Loss: 5.0495
Epoch [6/25], lter [21/1752] Loss: 8.8887
Epoch [6/25], lter [31/1752] Loss: 2.5335
Epoch [6/25], lter [41/1752] Loss: 5.8936
Epoch [6/25], lter [51/1752] Loss: 6.1051
Epoch [6/25], lter [61/1752] Loss: 8.6210
Epoch [6/25], lter [71/1752] Loss: 6.6862
Epoch [6/25], lter [81/1752] Loss: 8.1542
Epoch [6/25], lter [91/1752] Loss: 6.7874
Epoch [6/25], lter [101/1752] Loss: 8.6667
Epoch [6/25], lter [111/1752] Loss: 4.7931
Epoch [6/25], lter [121/1752] Loss: 4.3885
Epoch [6/25], lter [131/1752] Loss: 4.4610
Epoch [6/25], lter [141/1752] Loss: 4.4784
Epoch [6/25], lter [151/1752] Loss: 6.3763
Epoch [6/25], lter [161/1752] Loss: 5.6695
Epoch [6/25], lter [171/1752] Loss: 3.5555
Epoch [6/25], lter [181/1752] Loss: 7.2066
Epoch [6/25], lter [191/1752] Loss: 6.0658
Epoch [6/25], lter [201/1752] Loss: 4.4172
Epoch [6/25], lter [211/1752] Loss: 5.7925
Epoch [6/25], lter [221/1752] Lo

 24%|██▍       | 6/25 [2:46:19<8:45:47, 1660.42s/it]

Epoch:  5 | test loss : 3.5856
Epoch [7/25], lter [1/1752] Loss: 2.2894
Epoch [7/25], lter [11/1752] Loss: 3.4046
Epoch [7/25], lter [21/1752] Loss: 3.4411
Epoch [7/25], lter [31/1752] Loss: 5.8546
Epoch [7/25], lter [41/1752] Loss: 5.2594
Epoch [7/25], lter [51/1752] Loss: 4.8968
Epoch [7/25], lter [61/1752] Loss: 3.4812
Epoch [7/25], lter [71/1752] Loss: 4.2201
Epoch [7/25], lter [81/1752] Loss: 3.7530
Epoch [7/25], lter [91/1752] Loss: 4.0350
Epoch [7/25], lter [101/1752] Loss: 4.0054
Epoch [7/25], lter [111/1752] Loss: 4.2948
Epoch [7/25], lter [121/1752] Loss: 3.3874
Epoch [7/25], lter [131/1752] Loss: 4.3300
Epoch [7/25], lter [141/1752] Loss: 1.5404
Epoch [7/25], lter [151/1752] Loss: 4.1792
Epoch [7/25], lter [161/1752] Loss: 4.4627
Epoch [7/25], lter [171/1752] Loss: 2.8676
Epoch [7/25], lter [181/1752] Loss: 2.5037
Epoch [7/25], lter [191/1752] Loss: 3.0035
Epoch [7/25], lter [201/1752] Loss: 2.7770
Epoch [7/25], lter [211/1752] Loss: 3.1048
Epoch [7/25], lter [221/1752] Loss

 28%|██▊       | 7/25 [3:14:06<8:18:46, 1662.60s/it]

Epoch:  6 | test loss : 3.7164
Epoch [8/25], lter [1/1752] Loss: 3.0450
Epoch [8/25], lter [11/1752] Loss: 2.2605
Epoch [8/25], lter [21/1752] Loss: 2.6103
Epoch [8/25], lter [31/1752] Loss: 2.5583
Epoch [8/25], lter [41/1752] Loss: 3.6326
Epoch [8/25], lter [51/1752] Loss: 2.6165
Epoch [8/25], lter [61/1752] Loss: 4.0550
Epoch [8/25], lter [71/1752] Loss: 2.9033
Epoch [8/25], lter [81/1752] Loss: 2.7243
Epoch [8/25], lter [91/1752] Loss: 3.1669
Epoch [8/25], lter [101/1752] Loss: 1.1630
Epoch [8/25], lter [111/1752] Loss: 5.9655
Epoch [8/25], lter [121/1752] Loss: 1.8493
Epoch [8/25], lter [131/1752] Loss: 2.9484
Epoch [8/25], lter [141/1752] Loss: 1.7464
Epoch [8/25], lter [151/1752] Loss: 3.2501
Epoch [8/25], lter [161/1752] Loss: 4.8610
Epoch [8/25], lter [171/1752] Loss: 2.5318
Epoch [8/25], lter [181/1752] Loss: 2.4235
Epoch [8/25], lter [191/1752] Loss: 2.3583
Epoch [8/25], lter [201/1752] Loss: 2.6038
Epoch [8/25], lter [211/1752] Loss: 2.1255
Epoch [8/25], lter [221/1752] Loss

 32%|███▏      | 8/25 [3:41:54<7:51:33, 1664.30s/it]

Epoch:  7 | test loss : 2.7736
Epoch [9/25], lter [1/1752] Loss: 0.6219
Epoch [9/25], lter [11/1752] Loss: 2.5679
Epoch [9/25], lter [21/1752] Loss: 4.9479
Epoch [9/25], lter [31/1752] Loss: 3.2323
Epoch [9/25], lter [41/1752] Loss: 2.0569
Epoch [9/25], lter [51/1752] Loss: 4.8390
Epoch [9/25], lter [61/1752] Loss: 3.5957
Epoch [9/25], lter [71/1752] Loss: 3.3404
Epoch [9/25], lter [81/1752] Loss: 3.2694
Epoch [9/25], lter [91/1752] Loss: 2.2173
Epoch [9/25], lter [101/1752] Loss: 2.5653
Epoch [9/25], lter [111/1752] Loss: 3.1404
Epoch [9/25], lter [121/1752] Loss: 1.8052
Epoch [9/25], lter [131/1752] Loss: 2.5580
Epoch [9/25], lter [141/1752] Loss: 3.3690
Epoch [9/25], lter [151/1752] Loss: 3.8189
Epoch [9/25], lter [161/1752] Loss: 2.7462
Epoch [9/25], lter [171/1752] Loss: 3.3998
Epoch [9/25], lter [181/1752] Loss: 2.5382
Epoch [9/25], lter [191/1752] Loss: 2.3349
Epoch [9/25], lter [201/1752] Loss: 1.7780
Epoch [9/25], lter [211/1752] Loss: 3.2986
Epoch [9/25], lter [221/1752] Loss

 36%|███▌      | 9/25 [4:09:39<7:23:54, 1664.67s/it]

Epoch:  8 | test loss : 2.4138
Epoch [10/25], lter [1/1752] Loss: 1.6425
Epoch [10/25], lter [11/1752] Loss: 2.2953
Epoch [10/25], lter [21/1752] Loss: 1.5464
Epoch [10/25], lter [31/1752] Loss: 2.3961
Epoch [10/25], lter [41/1752] Loss: 1.7028
Epoch [10/25], lter [51/1752] Loss: 2.4953
Epoch [10/25], lter [61/1752] Loss: 2.1665
Epoch [10/25], lter [71/1752] Loss: 2.7015
Epoch [10/25], lter [81/1752] Loss: 2.0914
Epoch [10/25], lter [91/1752] Loss: 2.5709
Epoch [10/25], lter [101/1752] Loss: 2.1638
Epoch [10/25], lter [111/1752] Loss: 2.8726
Epoch [10/25], lter [121/1752] Loss: 3.4978
Epoch [10/25], lter [131/1752] Loss: 1.6591
Epoch [10/25], lter [141/1752] Loss: 1.3169
Epoch [10/25], lter [151/1752] Loss: 1.8324
Epoch [10/25], lter [161/1752] Loss: 2.1917
Epoch [10/25], lter [171/1752] Loss: 2.2899
Epoch [10/25], lter [181/1752] Loss: 2.1741
Epoch [10/25], lter [191/1752] Loss: 2.4521
Epoch [10/25], lter [201/1752] Loss: 2.4385
Epoch [10/25], lter [211/1752] Loss: 3.6566
Epoch [10/25

 40%|████      | 10/25 [4:37:18<6:55:43, 1662.93s/it]

Epoch:  9 | test loss : 2.4688
Epoch [11/25], lter [1/1752] Loss: 1.4392
Epoch [11/25], lter [11/1752] Loss: 1.8756
Epoch [11/25], lter [21/1752] Loss: 1.2187
Epoch [11/25], lter [31/1752] Loss: 2.0141
Epoch [11/25], lter [41/1752] Loss: 1.3806
Epoch [11/25], lter [51/1752] Loss: 1.4197
Epoch [11/25], lter [61/1752] Loss: 1.5156
Epoch [11/25], lter [71/1752] Loss: 1.6822
Epoch [11/25], lter [81/1752] Loss: 2.5303
Epoch [11/25], lter [91/1752] Loss: 2.0991
Epoch [11/25], lter [101/1752] Loss: 2.1295
Epoch [11/25], lter [111/1752] Loss: 3.1012
Epoch [11/25], lter [121/1752] Loss: 0.9862
Epoch [11/25], lter [131/1752] Loss: 1.7317
Epoch [11/25], lter [141/1752] Loss: 1.6357
Epoch [11/25], lter [151/1752] Loss: 2.6708
Epoch [11/25], lter [161/1752] Loss: 1.4954
Epoch [11/25], lter [171/1752] Loss: 2.8384
Epoch [11/25], lter [181/1752] Loss: 3.8107
Epoch [11/25], lter [191/1752] Loss: 1.1911
Epoch [11/25], lter [201/1752] Loss: 3.2758
Epoch [11/25], lter [211/1752] Loss: 2.2337
Epoch [11/25

 44%|████▍     | 11/25 [5:04:57<6:27:43, 1661.65s/it]

Epoch:  10 | test loss : 2.1175
Epoch [12/25], lter [1/1752] Loss: 2.6825
Epoch [12/25], lter [11/1752] Loss: 3.4733
Epoch [12/25], lter [21/1752] Loss: 2.4469
Epoch [12/25], lter [31/1752] Loss: 1.1776
Epoch [12/25], lter [41/1752] Loss: 2.5851
Epoch [12/25], lter [51/1752] Loss: 1.8988
Epoch [12/25], lter [61/1752] Loss: 2.2054
Epoch [12/25], lter [71/1752] Loss: 1.5257
Epoch [12/25], lter [81/1752] Loss: 2.9154
Epoch [12/25], lter [91/1752] Loss: 2.4129
Epoch [12/25], lter [101/1752] Loss: 2.3918
Epoch [12/25], lter [111/1752] Loss: 1.8724
Epoch [12/25], lter [121/1752] Loss: 1.5117
Epoch [12/25], lter [131/1752] Loss: 1.5148
Epoch [12/25], lter [141/1752] Loss: 1.6490
Epoch [12/25], lter [151/1752] Loss: 2.0336
Epoch [12/25], lter [161/1752] Loss: 1.7148
Epoch [12/25], lter [171/1752] Loss: 2.4529
Epoch [12/25], lter [181/1752] Loss: 2.1978
Epoch [12/25], lter [191/1752] Loss: 1.4461
Epoch [12/25], lter [201/1752] Loss: 1.9514
Epoch [12/25], lter [211/1752] Loss: 2.0758
Epoch [12/2

 48%|████▊     | 12/25 [5:32:27<5:59:14, 1658.04s/it]

Epoch:  11 | test loss : 2.1412
Epoch [13/25], lter [1/1752] Loss: 1.4293
Epoch [13/25], lter [11/1752] Loss: 2.0740
Epoch [13/25], lter [21/1752] Loss: 1.4282
Epoch [13/25], lter [31/1752] Loss: 1.3414
Epoch [13/25], lter [41/1752] Loss: 1.6706
Epoch [13/25], lter [51/1752] Loss: 2.5188
Epoch [13/25], lter [61/1752] Loss: 1.5327
Epoch [13/25], lter [71/1752] Loss: 2.6064
Epoch [13/25], lter [81/1752] Loss: 1.9951
Epoch [13/25], lter [91/1752] Loss: 1.5597
Epoch [13/25], lter [101/1752] Loss: 1.5916
Epoch [13/25], lter [111/1752] Loss: 1.8408
Epoch [13/25], lter [121/1752] Loss: 2.4573
Epoch [13/25], lter [131/1752] Loss: 1.1769
Epoch [13/25], lter [141/1752] Loss: 1.8668
Epoch [13/25], lter [151/1752] Loss: 1.7596
Epoch [13/25], lter [161/1752] Loss: 1.9644
Epoch [13/25], lter [171/1752] Loss: 1.0398
Epoch [13/25], lter [181/1752] Loss: 1.6112
Epoch [13/25], lter [191/1752] Loss: 2.4849
Epoch [13/25], lter [201/1752] Loss: 2.4912
Epoch [13/25], lter [211/1752] Loss: 2.0358
Epoch [13/2

 52%|█████▏    | 13/25 [6:00:02<5:31:26, 1657.21s/it]

Epoch:  12 | test loss : 2.0523
Epoch [14/25], lter [1/1752] Loss: 1.9757
Epoch [14/25], lter [11/1752] Loss: 1.5047
Epoch [14/25], lter [21/1752] Loss: 1.7368
Epoch [14/25], lter [31/1752] Loss: 2.7709
Epoch [14/25], lter [41/1752] Loss: 1.4851
Epoch [14/25], lter [51/1752] Loss: 2.2461
Epoch [14/25], lter [61/1752] Loss: 2.1145
Epoch [14/25], lter [71/1752] Loss: 1.6817
Epoch [14/25], lter [81/1752] Loss: 1.7369
Epoch [14/25], lter [91/1752] Loss: 2.7903
Epoch [14/25], lter [101/1752] Loss: 1.2762
Epoch [14/25], lter [111/1752] Loss: 1.1377
Epoch [14/25], lter [121/1752] Loss: 1.7070
Epoch [14/25], lter [131/1752] Loss: 1.5900
Epoch [14/25], lter [141/1752] Loss: 1.7533
Epoch [14/25], lter [151/1752] Loss: 1.8609
Epoch [14/25], lter [161/1752] Loss: 1.6911
Epoch [14/25], lter [171/1752] Loss: 2.0808
Epoch [14/25], lter [181/1752] Loss: 1.4125
Epoch [14/25], lter [191/1752] Loss: 1.5996
Epoch [14/25], lter [201/1752] Loss: 3.1733
Epoch [14/25], lter [211/1752] Loss: 1.3766
Epoch [14/2

 56%|█████▌    | 14/25 [6:27:52<5:04:31, 1661.05s/it]

Epoch:  13 | test loss : 2.0901
Epoch [15/25], lter [1/1752] Loss: 1.8113
Epoch [15/25], lter [11/1752] Loss: 2.3740
Epoch [15/25], lter [21/1752] Loss: 1.5059
Epoch [15/25], lter [31/1752] Loss: 2.3539
Epoch [15/25], lter [41/1752] Loss: 1.9050
Epoch [15/25], lter [51/1752] Loss: 1.8400
Epoch [15/25], lter [61/1752] Loss: 2.1695
Epoch [15/25], lter [71/1752] Loss: 2.5314
Epoch [15/25], lter [81/1752] Loss: 2.2994
Epoch [15/25], lter [91/1752] Loss: 0.6521
Epoch [15/25], lter [101/1752] Loss: 1.0169
Epoch [15/25], lter [111/1752] Loss: 2.4260
Epoch [15/25], lter [121/1752] Loss: 2.1988
Epoch [15/25], lter [131/1752] Loss: 1.7976
Epoch [15/25], lter [141/1752] Loss: 1.6949
Epoch [15/25], lter [151/1752] Loss: 2.6172
Epoch [15/25], lter [161/1752] Loss: 1.4359
Epoch [15/25], lter [171/1752] Loss: 1.4342
Epoch [15/25], lter [181/1752] Loss: 1.8950
Epoch [15/25], lter [191/1752] Loss: 1.6115
Epoch [15/25], lter [201/1752] Loss: 1.4079
Epoch [15/25], lter [211/1752] Loss: 3.3392
Epoch [15/2

 60%|██████    | 15/25 [6:55:38<4:37:07, 1662.74s/it]

Epoch:  14 | test loss : 2.1318
Epoch [16/25], lter [1/1752] Loss: 1.8435
Epoch [16/25], lter [11/1752] Loss: 2.5388
Epoch [16/25], lter [21/1752] Loss: 2.8006
Epoch [16/25], lter [31/1752] Loss: 2.0061
Epoch [16/25], lter [41/1752] Loss: 2.4766
Epoch [16/25], lter [51/1752] Loss: 2.1963
Epoch [16/25], lter [61/1752] Loss: 1.3648
Epoch [16/25], lter [71/1752] Loss: 1.6593
Epoch [16/25], lter [81/1752] Loss: 1.3678
Epoch [16/25], lter [91/1752] Loss: 1.1937
Epoch [16/25], lter [101/1752] Loss: 1.1776
Epoch [16/25], lter [111/1752] Loss: 1.5723
Epoch [16/25], lter [121/1752] Loss: 1.9086
Epoch [16/25], lter [131/1752] Loss: 2.5357
Epoch [16/25], lter [141/1752] Loss: 1.6029
Epoch [16/25], lter [151/1752] Loss: 1.4886
Epoch [16/25], lter [161/1752] Loss: 2.1470
Epoch [16/25], lter [171/1752] Loss: 1.5724
Epoch [16/25], lter [181/1752] Loss: 1.8747
Epoch [16/25], lter [191/1752] Loss: 1.8575
Epoch [16/25], lter [201/1752] Loss: 1.7548
Epoch [16/25], lter [211/1752] Loss: 1.3347
Epoch [16/2

 64%|██████▍   | 16/25 [7:23:26<4:09:38, 1664.31s/it]

Epoch:  15 | test loss : 2.0290
Epoch [17/25], lter [1/1752] Loss: 1.4518
Epoch [17/25], lter [11/1752] Loss: 1.2254
Epoch [17/25], lter [21/1752] Loss: 2.4643
Epoch [17/25], lter [31/1752] Loss: 2.1230
Epoch [17/25], lter [41/1752] Loss: 2.0345
Epoch [17/25], lter [51/1752] Loss: 1.9611
Epoch [17/25], lter [61/1752] Loss: 1.8283
Epoch [17/25], lter [71/1752] Loss: 1.5546
Epoch [17/25], lter [81/1752] Loss: 2.3799
Epoch [17/25], lter [91/1752] Loss: 2.0195
Epoch [17/25], lter [101/1752] Loss: 1.8152
Epoch [17/25], lter [111/1752] Loss: 2.0757
Epoch [17/25], lter [121/1752] Loss: 1.6019
Epoch [17/25], lter [131/1752] Loss: 2.0731
Epoch [17/25], lter [141/1752] Loss: 2.0737
Epoch [17/25], lter [151/1752] Loss: 3.2919
Epoch [17/25], lter [161/1752] Loss: 1.8200
Epoch [17/25], lter [171/1752] Loss: 2.4129
Epoch [17/25], lter [181/1752] Loss: 2.5304
Epoch [17/25], lter [191/1752] Loss: 2.0942
Epoch [17/25], lter [201/1752] Loss: 1.0744
Epoch [17/25], lter [211/1752] Loss: 1.4256
Epoch [17/2

 68%|██████▊   | 17/25 [7:51:06<3:41:43, 1662.89s/it]

Epoch:  16 | test loss : 2.0298
Epoch [18/25], lter [1/1752] Loss: 1.9908
Epoch [18/25], lter [11/1752] Loss: 2.3680
Epoch [18/25], lter [21/1752] Loss: 2.6582
Epoch [18/25], lter [31/1752] Loss: 1.6634
Epoch [18/25], lter [41/1752] Loss: 2.8304
Epoch [18/25], lter [51/1752] Loss: 2.2399
Epoch [18/25], lter [61/1752] Loss: 1.5801
Epoch [18/25], lter [71/1752] Loss: 2.2891
Epoch [18/25], lter [81/1752] Loss: 1.5584
Epoch [18/25], lter [91/1752] Loss: 2.0572
Epoch [18/25], lter [101/1752] Loss: 1.9538
Epoch [18/25], lter [111/1752] Loss: 1.4722
Epoch [18/25], lter [121/1752] Loss: 2.4986
Epoch [18/25], lter [131/1752] Loss: 1.4630
Epoch [18/25], lter [141/1752] Loss: 1.5420
Epoch [18/25], lter [151/1752] Loss: 3.5695
Epoch [18/25], lter [161/1752] Loss: 1.8826
Epoch [18/25], lter [171/1752] Loss: 0.8437
Epoch [18/25], lter [181/1752] Loss: 2.2206
Epoch [18/25], lter [191/1752] Loss: 1.4740
Epoch [18/25], lter [201/1752] Loss: 2.4864
Epoch [18/25], lter [211/1752] Loss: 1.8044
Epoch [18/2

 72%|███████▏  | 18/25 [8:18:54<3:14:11, 1664.54s/it]

Epoch:  17 | test loss : 2.0599
Epoch [19/25], lter [1/1752] Loss: 1.3404
Epoch [19/25], lter [11/1752] Loss: 1.3470
Epoch [19/25], lter [21/1752] Loss: 2.4759
Epoch [19/25], lter [31/1752] Loss: 1.7066
Epoch [19/25], lter [41/1752] Loss: 1.3960
Epoch [19/25], lter [51/1752] Loss: 1.9836
Epoch [19/25], lter [61/1752] Loss: 1.4416
Epoch [19/25], lter [71/1752] Loss: 1.4216
Epoch [19/25], lter [81/1752] Loss: 1.8755
Epoch [19/25], lter [91/1752] Loss: 1.7556
Epoch [19/25], lter [101/1752] Loss: 2.0165
Epoch [19/25], lter [111/1752] Loss: 1.7584
Epoch [19/25], lter [121/1752] Loss: 1.4639
Epoch [19/25], lter [131/1752] Loss: 1.9509
Epoch [19/25], lter [141/1752] Loss: 2.6247
Epoch [19/25], lter [151/1752] Loss: 1.5346
Epoch [19/25], lter [161/1752] Loss: 2.4759
Epoch [19/25], lter [171/1752] Loss: 2.1612
Epoch [19/25], lter [181/1752] Loss: 2.2167
Epoch [19/25], lter [191/1752] Loss: 2.3640
Epoch [19/25], lter [201/1752] Loss: 1.3779
Epoch [19/25], lter [211/1752] Loss: 2.6379
Epoch [19/2

 76%|███████▌  | 19/25 [8:46:34<2:46:18, 1663.04s/it]

Epoch:  18 | test loss : 2.0422
Epoch [20/25], lter [1/1752] Loss: 1.8804
Epoch [20/25], lter [11/1752] Loss: 2.0869
Epoch [20/25], lter [21/1752] Loss: 2.5402
Epoch [20/25], lter [31/1752] Loss: 1.7951
Epoch [20/25], lter [41/1752] Loss: 0.7974
Epoch [20/25], lter [51/1752] Loss: 1.8843
Epoch [20/25], lter [61/1752] Loss: 1.7619
Epoch [20/25], lter [71/1752] Loss: 3.0356
Epoch [20/25], lter [81/1752] Loss: 1.6636
Epoch [20/25], lter [91/1752] Loss: 2.0275
Epoch [20/25], lter [101/1752] Loss: 1.8778
Epoch [20/25], lter [111/1752] Loss: 2.5321
Epoch [20/25], lter [121/1752] Loss: 1.9221
Epoch [20/25], lter [131/1752] Loss: 1.8535
Epoch [20/25], lter [141/1752] Loss: 1.8642
Epoch [20/25], lter [151/1752] Loss: 1.6932
Epoch [20/25], lter [161/1752] Loss: 1.5065
Epoch [20/25], lter [171/1752] Loss: 1.6539
Epoch [20/25], lter [181/1752] Loss: 2.7577
Epoch [20/25], lter [191/1752] Loss: 1.6482
Epoch [20/25], lter [201/1752] Loss: 1.2778
Epoch [20/25], lter [211/1752] Loss: 1.0706
Epoch [20/2

 80%|████████  | 20/25 [9:14:22<2:18:42, 1664.55s/it]

Epoch:  19 | test loss : 2.0098
Epoch [21/25], lter [1/1752] Loss: 2.1595
Epoch [21/25], lter [11/1752] Loss: 2.0846
Epoch [21/25], lter [21/1752] Loss: 0.8369
Epoch [21/25], lter [31/1752] Loss: 2.1616
Epoch [21/25], lter [41/1752] Loss: 1.9266
Epoch [21/25], lter [51/1752] Loss: 1.2044
Epoch [21/25], lter [61/1752] Loss: 1.8193
Epoch [21/25], lter [71/1752] Loss: 1.9295
Epoch [21/25], lter [81/1752] Loss: 1.4359
Epoch [21/25], lter [91/1752] Loss: 2.4298
Epoch [21/25], lter [101/1752] Loss: 1.5229
Epoch [21/25], lter [111/1752] Loss: 1.6303
Epoch [21/25], lter [121/1752] Loss: 2.4472
Epoch [21/25], lter [131/1752] Loss: 2.4134
Epoch [21/25], lter [141/1752] Loss: 1.4575
Epoch [21/25], lter [151/1752] Loss: 1.7297
Epoch [21/25], lter [161/1752] Loss: 2.1226
Epoch [21/25], lter [171/1752] Loss: 2.4532
Epoch [21/25], lter [181/1752] Loss: 1.2607
Epoch [21/25], lter [191/1752] Loss: 1.3185
Epoch [21/25], lter [201/1752] Loss: 1.8754
Epoch [21/25], lter [211/1752] Loss: 1.0584
Epoch [21/2

 84%|████████▍ | 21/25 [9:42:08<1:51:00, 1665.09s/it]

Epoch:  20 | test loss : 2.0219
Epoch [22/25], lter [1/1752] Loss: 1.8256
Epoch [22/25], lter [11/1752] Loss: 1.3849
Epoch [22/25], lter [21/1752] Loss: 1.6184
Epoch [22/25], lter [31/1752] Loss: 1.5104
Epoch [22/25], lter [41/1752] Loss: 1.7374
Epoch [22/25], lter [51/1752] Loss: 1.2400
Epoch [22/25], lter [61/1752] Loss: 2.2572
Epoch [22/25], lter [71/1752] Loss: 1.8562
Epoch [22/25], lter [81/1752] Loss: 1.4074
Epoch [22/25], lter [91/1752] Loss: 2.1440
Epoch [22/25], lter [101/1752] Loss: 2.1728
Epoch [22/25], lter [111/1752] Loss: 1.8055
Epoch [22/25], lter [121/1752] Loss: 1.8135
Epoch [22/25], lter [131/1752] Loss: 2.0170
Epoch [22/25], lter [141/1752] Loss: 2.3908
Epoch [22/25], lter [151/1752] Loss: 1.5229
Epoch [22/25], lter [161/1752] Loss: 2.7343
Epoch [22/25], lter [171/1752] Loss: 1.5800
Epoch [22/25], lter [181/1752] Loss: 1.2897
Epoch [22/25], lter [191/1752] Loss: 1.3269
Epoch [22/25], lter [201/1752] Loss: 1.6373
Epoch [22/25], lter [211/1752] Loss: 1.9509
Epoch [22/2

 88%|████████▊ | 22/25 [10:09:48<1:23:10, 1663.49s/it]

Epoch:  21 | test loss : 2.0279
Epoch [23/25], lter [1/1752] Loss: 1.8065
Epoch [23/25], lter [11/1752] Loss: 1.8668
Epoch [23/25], lter [21/1752] Loss: 2.1949
Epoch [23/25], lter [31/1752] Loss: 1.4383
Epoch [23/25], lter [41/1752] Loss: 1.9561
Epoch [23/25], lter [51/1752] Loss: 1.7535
Epoch [23/25], lter [61/1752] Loss: 3.2665
Epoch [23/25], lter [71/1752] Loss: 2.2044
Epoch [23/25], lter [81/1752] Loss: 1.9934
Epoch [23/25], lter [91/1752] Loss: 2.3470
Epoch [23/25], lter [101/1752] Loss: 1.7594
Epoch [23/25], lter [111/1752] Loss: 1.4456
Epoch [23/25], lter [121/1752] Loss: 1.0884
Epoch [23/25], lter [131/1752] Loss: 2.4591
Epoch [23/25], lter [141/1752] Loss: 2.7102
Epoch [23/25], lter [151/1752] Loss: 1.2665
Epoch [23/25], lter [161/1752] Loss: 1.8976
Epoch [23/25], lter [171/1752] Loss: 2.2752
Epoch [23/25], lter [181/1752] Loss: 1.2308
Epoch [23/25], lter [191/1752] Loss: 1.6710
Epoch [23/25], lter [201/1752] Loss: 2.1138
Epoch [23/25], lter [211/1752] Loss: 1.8693
Epoch [23/2

 92%|█████████▏| 23/25 [10:37:28<55:25, 1662.53s/it]  

Epoch:  22 | test loss : 2.0318
Epoch [24/25], lter [1/1752] Loss: 1.3431
Epoch [24/25], lter [11/1752] Loss: 1.7369
Epoch [24/25], lter [21/1752] Loss: 2.4375
Epoch [24/25], lter [31/1752] Loss: 1.7923
Epoch [24/25], lter [41/1752] Loss: 1.8515
Epoch [24/25], lter [51/1752] Loss: 2.4525
Epoch [24/25], lter [61/1752] Loss: 2.2459
Epoch [24/25], lter [71/1752] Loss: 2.4314
Epoch [24/25], lter [81/1752] Loss: 1.6011
Epoch [24/25], lter [91/1752] Loss: 1.7558
Epoch [24/25], lter [101/1752] Loss: 2.2122
Epoch [24/25], lter [111/1752] Loss: 1.5766
Epoch [24/25], lter [121/1752] Loss: 1.1077
Epoch [24/25], lter [131/1752] Loss: 1.3303
Epoch [24/25], lter [141/1752] Loss: 2.7778
Epoch [24/25], lter [151/1752] Loss: 2.6302
Epoch [24/25], lter [161/1752] Loss: 1.9314
Epoch [24/25], lter [171/1752] Loss: 1.7987
Epoch [24/25], lter [181/1752] Loss: 1.0478
Epoch [24/25], lter [191/1752] Loss: 1.3958
Epoch [24/25], lter [201/1752] Loss: 1.0508
Epoch [24/25], lter [211/1752] Loss: 1.6010
Epoch [24/2

 96%|█████████▌| 24/25 [11:05:15<27:43, 1663.78s/it]

Epoch:  23 | test loss : 2.0045
Epoch [25/25], lter [1/1752] Loss: 1.4326
Epoch [25/25], lter [11/1752] Loss: 1.2760
Epoch [25/25], lter [21/1752] Loss: 2.5940
Epoch [25/25], lter [31/1752] Loss: 2.2064
Epoch [25/25], lter [41/1752] Loss: 1.7602
Epoch [25/25], lter [51/1752] Loss: 1.3228
Epoch [25/25], lter [61/1752] Loss: 1.5645
Epoch [25/25], lter [71/1752] Loss: 2.0609
Epoch [25/25], lter [81/1752] Loss: 0.9280
Epoch [25/25], lter [91/1752] Loss: 1.7571
Epoch [25/25], lter [101/1752] Loss: 1.9965
Epoch [25/25], lter [111/1752] Loss: 1.3140
Epoch [25/25], lter [121/1752] Loss: 2.0486
Epoch [25/25], lter [131/1752] Loss: 2.4989
Epoch [25/25], lter [141/1752] Loss: 1.9747
Epoch [25/25], lter [151/1752] Loss: 1.4622
Epoch [25/25], lter [161/1752] Loss: 2.0292
Epoch [25/25], lter [171/1752] Loss: 1.9456
Epoch [25/25], lter [181/1752] Loss: 1.7801
Epoch [25/25], lter [191/1752] Loss: 1.9872
Epoch [25/25], lter [201/1752] Loss: 1.7041
Epoch [25/25], lter [211/1752] Loss: 1.2126
Epoch [25/2

100%|██████████| 25/25 [11:32:55<00:00, 1663.04s/it]

Epoch:  24 | test loss : 1.9970





In [9]:
correct_predictions, total_predictions = 0, 0

# Loop over the batches of the test set
for batch_data in dataloaders['train']:
    # Forward pass
    inputs, target = batch_data
    output = model(inputs)

    # Get the predicted class
    _, predicted = torch.max(output.data, 1)

    # Increase the total predictions counter
    total_predictions += target.size(0)

    # Check if the predicted class is equal to the true class
    correct_predictions += (predicted == target).sum().item()

# Calculate the accuracy
accuracy = correct_predictions / total_predictions
print(f'Accuracy: {accuracy * 100:.2f}%')

Accuracy: 46.58%


In [10]:
correct_predictions, total_predictions = 0, 0

# Loop over the batches of the test set
for batch_data in dataloaders['test']:
    # Forward pass
    inputs, target = batch_data
    output = model(inputs)

    # Get the predicted class
    _, predicted = torch.max(output.data, 1)

    # Increase the total predictions counter
    total_predictions += target.size(0)

    # Check if the predicted class is equal to the true class
    correct_predictions += (predicted == target).sum().item()

# Calculate the accuracy
accuracy = correct_predictions / total_predictions
print(f'Accuracy: {accuracy * 100:.2f}%')

Accuracy: 42.51%


In [11]:
# Save the model
torch.save(model.state_dict(), 'ResNet50_v2.pth')