In [1]:
#!pip install datasets
#!pip install --upgrade --force-reinstall huggingface_hub

In [2]:
import torch
import torchvision
from torchvision.transforms import transforms, Lambda, Resize
from torchvision import datasets, transforms, models
import os
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Specify where to find the data preparation class
import sys
sys.path.append('../../Data_Preparation')
from Preparation import CustomDataLoader

In [3]:
# ResNet50 training data (ImageNet) properties
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
#DIMENSIONS = 3
#SIZE = 256

LR = 0.01
MOMENTUM=0.9
WEIGHT_DECAY = 1e-4
LR_STEP_SIZE = 5
LR_GAMMA = 0.1

In [4]:
dataloaders = {x: CustomDataLoader(data_path="../../FER2013_Data", batch_size=16, dataset_type=x, mean=MEAN, std=STD, dimensions=3).data_loader for x in ['train', 'test']}
dataset_sizes = {x: len(dataloaders[x]) for x in ['train', 'test']}
print(dataset_sizes)

class_names = dataloaders['train'].dataset.classes
print(class_names)

# Confirm correct data load
print("Train Data Loader:")
for batch_idx, (inputs, labels) in enumerate(dataloaders['train']):
    print("Batch Index:", batch_idx)
    print("Inputs Shape:", inputs.shape)
    print("Labels Shape:", labels.shape)
    # Print the first few labels in the batch
    print("Labels:", labels[:5])
    # Break after printing a few batches
    if batch_idx == 2:
        break

{'train': 1753, 'test': 449}
['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
Train Data Loader:
Batch Index: 0
Inputs Shape: torch.Size([16, 3, 299, 299])
Labels Shape: torch.Size([16])
Labels: tensor([3, 3, 3, 2, 4])
Batch Index: 1
Inputs Shape: torch.Size([16, 3, 299, 299])
Labels Shape: torch.Size([16])
Labels: tensor([6, 3, 5, 3, 5])
Batch Index: 2
Inputs Shape: torch.Size([16, 3, 299, 299])
Labels Shape: torch.Size([16])
Labels: tensor([4, 3, 0, 6, 3])


In [5]:
dataloaders['test']

<torch.utils.data.dataloader.DataLoader at 0x1e704928f40>

In [6]:
# Load the pre-trained ResNet-50 model
model = models.resnet50(pretrained=True)

# Freeze all layers except the final classification layer
for name, param in model.named_parameters():
    if "fc" in name:  # Unfreeze the final classification layer
        param.requires_grad = True
    else:
        param.requires_grad = False

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY
)

LR_SCHEDULER = torch.optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)


# Move the model to the GPU if available
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model = model.to(device)



In [7]:
# num_epochs = 20
# for epoch in tqdm(range(num_epochs)):
#     for phase in ['train', 'test']:
#         if phase == 'train':
#             model.train()
#         else:
#             model.eval()

#         running_loss = 0.0
#         running_corrects = 0

#         for inputs, labels in dataloaders[phase]:
#             inputs = inputs.to(device)
#             labels = labels.to(device)

#             if phase == 'train':

#                 optimizer.zero_grad()

#                 with torch.set_grad_enabled(phase == 'train'):
#                     outputs = model(inputs)
#                     _, preds = torch.max(outputs, 1)
#                     loss = criterion(outputs, labels)

#                     if phase == 'train':
#                         loss.backward()
#                         optimizer.step()

#                 running_loss += loss.item() * inputs.size(0)
#                 running_corrects += torch.sum(preds == labels.data)

#             LR_SCHEDULER.step()

#             elif phase == 'test':
#                 with torch.inference_mode():  
#                     running_loss = 0
#                     for step, (images, labels) in enumerate(dataloaders[phase]):
#                         images, labels = images.to(device), labels.to(device)
                        
#                         outputs = model(images)
#                         loss = criterion(outputs, labels)

#                         running_loss += loss.item()

#     epoch_loss = running_loss / dataset_sizes[phase]
#     epoch_acc = running_corrects.double() / dataset_sizes[phase]

#     print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
# print("Training complete!")

In [8]:
num_epochs = 20

print("[INFO] Start training")
print("---------------------")

for epoch in tqdm(range(num_epochs)):

    total_batch = len(dataloaders['train'].dataset)//16
    
    # Training phase
    model.train()  # Set the model to train mode

    # Get statistics
    epoch_loss = 0
    len_dataset = 0
    
    for step, (batch_images, batch_labels) in enumerate(dataloaders['train']):
        X, Y = batch_images.to(device), batch_labels.to(device)

        pred = model(X)
        loss = criterion(pred, Y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += pred.shape[0] * loss.item()
        len_dataset += pred.shape[0]
        if (step) % 10 == 0:
            print('Epoch [%d/%d], lter [%d/%d] Loss: %.4f'
                 %(epoch+1, num_epochs, step+1, total_batch, loss.item()))

    epoch_loss = epoch_loss/ len_dataset
    print('Epoch: ', epoch+1, '| train loss : %0.4f' % epoch_loss)

    LR_SCHEDULER.step()
    
    # Validation phase
    model.eval()  # Set the model to evaluation mode
    with torch.inference_mode():  
        running_loss = 0
        for step, (images, labels) in enumerate(dataloaders['test']):
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            
    running_loss = running_loss / len(dataloaders['test'])
    print('Epoch: ', epoch, '| test loss : %0.4f' % running_loss )

[INFO] Start training
---------------------


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch [1/20], lter [1/1752] Loss: 8.5280
Epoch [1/20], lter [11/1752] Loss: 3.5170
Epoch [1/20], lter [21/1752] Loss: 3.0042
Epoch [1/20], lter [31/1752] Loss: 3.0312
Epoch [1/20], lter [41/1752] Loss: 3.0662
Epoch [1/20], lter [51/1752] Loss: 2.0513
Epoch [1/20], lter [61/1752] Loss: 3.3191
Epoch [1/20], lter [71/1752] Loss: 1.6901
Epoch [1/20], lter [81/1752] Loss: 2.7134
Epoch [1/20], lter [91/1752] Loss: 2.0342
Epoch [1/20], lter [101/1752] Loss: 1.5432
Epoch [1/20], lter [111/1752] Loss: 2.2246
Epoch [1/20], lter [121/1752] Loss: 1.9009
Epoch [1/20], lter [131/1752] Loss: 2.3794
Epoch [1/20], lter [141/1752] Loss: 4.2546
Epoch [1/20], lter [151/1752] Loss: 1.4273
Epoch [1/20], lter [161/1752] Loss: 1.3442
Epoch [1/20], lter [171/1752] Loss: 2.9274
Epoch [1/20], lter [181/1752] Loss: 2.4359
Epoch [1/20], lter [191/1752] Loss: 1.7520
Epoch [1/20], lter [201/1752] Loss: 4.1641
Epoch [1/20], lter [211/1752] Loss: 4.7436
Epoch [1/20], lter [221/1752] Loss: 1.8984
Epoch [1/20], lter [23

  5%|▌         | 1/20 [27:39<8:45:35, 1659.77s/it]

Epoch:  0 | test loss : 1.7332
Epoch [2/20], lter [1/1752] Loss: 1.4411
Epoch [2/20], lter [11/1752] Loss: 1.4073
Epoch [2/20], lter [21/1752] Loss: 2.3826
Epoch [2/20], lter [31/1752] Loss: 2.0099
Epoch [2/20], lter [41/1752] Loss: 2.8085
Epoch [2/20], lter [51/1752] Loss: 1.7508
Epoch [2/20], lter [61/1752] Loss: 2.7304
Epoch [2/20], lter [71/1752] Loss: 3.4405
Epoch [2/20], lter [81/1752] Loss: 2.4046
Epoch [2/20], lter [91/1752] Loss: 2.3934
Epoch [2/20], lter [101/1752] Loss: 2.1200
Epoch [2/20], lter [111/1752] Loss: 1.8039
Epoch [2/20], lter [121/1752] Loss: 2.6497
Epoch [2/20], lter [131/1752] Loss: 1.8135
Epoch [2/20], lter [141/1752] Loss: 1.7242
Epoch [2/20], lter [151/1752] Loss: 2.4385
Epoch [2/20], lter [161/1752] Loss: 3.4280
Epoch [2/20], lter [171/1752] Loss: 2.3567
Epoch [2/20], lter [181/1752] Loss: 6.4756
Epoch [2/20], lter [191/1752] Loss: 3.2092
Epoch [2/20], lter [201/1752] Loss: 4.0802
Epoch [2/20], lter [211/1752] Loss: 1.6955
Epoch [2/20], lter [221/1752] Loss

 10%|█         | 2/20 [55:13<8:16:47, 1656.00s/it]

Epoch:  1 | test loss : 2.0004
Epoch [3/20], lter [1/1752] Loss: 1.7494
Epoch [3/20], lter [11/1752] Loss: 2.7658
Epoch [3/20], lter [21/1752] Loss: 2.5255
Epoch [3/20], lter [31/1752] Loss: 2.9630
Epoch [3/20], lter [41/1752] Loss: 1.5594
Epoch [3/20], lter [51/1752] Loss: 1.9542
Epoch [3/20], lter [61/1752] Loss: 2.0910
Epoch [3/20], lter [71/1752] Loss: 1.1358
Epoch [3/20], lter [81/1752] Loss: 2.4522
Epoch [3/20], lter [91/1752] Loss: 2.6869
Epoch [3/20], lter [101/1752] Loss: 1.5905
Epoch [3/20], lter [111/1752] Loss: 1.3063
Epoch [3/20], lter [121/1752] Loss: 3.8071
Epoch [3/20], lter [131/1752] Loss: 1.4648
Epoch [3/20], lter [141/1752] Loss: 2.3015
Epoch [3/20], lter [151/1752] Loss: 2.8030
Epoch [3/20], lter [161/1752] Loss: 3.1934
Epoch [3/20], lter [171/1752] Loss: 2.0444
Epoch [3/20], lter [181/1752] Loss: 2.5125
Epoch [3/20], lter [191/1752] Loss: 1.6654
Epoch [3/20], lter [201/1752] Loss: 1.7732
Epoch [3/20], lter [211/1752] Loss: 2.0738
Epoch [3/20], lter [221/1752] Loss

 15%|█▌        | 3/20 [1:22:50<7:49:20, 1656.52s/it]

Epoch:  2 | test loss : 2.0453
Epoch [4/20], lter [1/1752] Loss: 1.7231
Epoch [4/20], lter [11/1752] Loss: 3.0869
Epoch [4/20], lter [21/1752] Loss: 1.9797
Epoch [4/20], lter [31/1752] Loss: 2.8566
Epoch [4/20], lter [41/1752] Loss: 2.3997
Epoch [4/20], lter [51/1752] Loss: 1.9991
Epoch [4/20], lter [61/1752] Loss: 3.2367
Epoch [4/20], lter [71/1752] Loss: 3.0875
Epoch [4/20], lter [81/1752] Loss: 1.9121
Epoch [4/20], lter [91/1752] Loss: 1.6065
Epoch [4/20], lter [101/1752] Loss: 2.0724
Epoch [4/20], lter [111/1752] Loss: 3.0694
Epoch [4/20], lter [121/1752] Loss: 2.3161
Epoch [4/20], lter [131/1752] Loss: 3.2776
Epoch [4/20], lter [141/1752] Loss: 2.4820
Epoch [4/20], lter [151/1752] Loss: 4.0883
Epoch [4/20], lter [161/1752] Loss: 0.9136
Epoch [4/20], lter [171/1752] Loss: 2.9598
Epoch [4/20], lter [181/1752] Loss: 3.5966
Epoch [4/20], lter [191/1752] Loss: 2.5133
Epoch [4/20], lter [201/1752] Loss: 1.5119
Epoch [4/20], lter [211/1752] Loss: 2.3124
Epoch [4/20], lter [221/1752] Loss

 20%|██        | 4/20 [1:50:29<7:21:59, 1657.45s/it]

Epoch:  3 | test loss : 2.1147
Epoch [5/20], lter [1/1752] Loss: 2.2684
Epoch [5/20], lter [11/1752] Loss: 2.3142
Epoch [5/20], lter [21/1752] Loss: 2.4660
Epoch [5/20], lter [31/1752] Loss: 2.0762
Epoch [5/20], lter [41/1752] Loss: 2.9295
Epoch [5/20], lter [51/1752] Loss: 2.0334
Epoch [5/20], lter [61/1752] Loss: 1.9729
Epoch [5/20], lter [71/1752] Loss: 1.8512
Epoch [5/20], lter [81/1752] Loss: 2.4345
Epoch [5/20], lter [91/1752] Loss: 2.6387
Epoch [5/20], lter [101/1752] Loss: 2.6874
Epoch [5/20], lter [111/1752] Loss: 1.0471
Epoch [5/20], lter [121/1752] Loss: 2.4759
Epoch [5/20], lter [131/1752] Loss: 1.1347
Epoch [5/20], lter [141/1752] Loss: 3.0056
Epoch [5/20], lter [151/1752] Loss: 2.5166
Epoch [5/20], lter [161/1752] Loss: 3.1257
Epoch [5/20], lter [171/1752] Loss: 2.5782
Epoch [5/20], lter [181/1752] Loss: 2.6930
Epoch [5/20], lter [191/1752] Loss: 2.8913
Epoch [5/20], lter [201/1752] Loss: 2.1427
Epoch [5/20], lter [211/1752] Loss: 1.5675
Epoch [5/20], lter [221/1752] Loss

 25%|██▌       | 5/20 [2:18:01<6:53:51, 1655.46s/it]

Epoch:  4 | test loss : 2.2607
Epoch [6/20], lter [1/1752] Loss: 1.8073
Epoch [6/20], lter [11/1752] Loss: 1.5414
Epoch [6/20], lter [21/1752] Loss: 2.0103
Epoch [6/20], lter [31/1752] Loss: 1.8459
Epoch [6/20], lter [41/1752] Loss: 1.6540
Epoch [6/20], lter [51/1752] Loss: 1.1361
Epoch [6/20], lter [61/1752] Loss: 1.4081
Epoch [6/20], lter [71/1752] Loss: 2.2272
Epoch [6/20], lter [81/1752] Loss: 1.6689
Epoch [6/20], lter [91/1752] Loss: 1.5384
Epoch [6/20], lter [101/1752] Loss: 1.6302
Epoch [6/20], lter [111/1752] Loss: 1.4951
Epoch [6/20], lter [121/1752] Loss: 2.0425
Epoch [6/20], lter [131/1752] Loss: 1.5763
Epoch [6/20], lter [141/1752] Loss: 1.6925
Epoch [6/20], lter [151/1752] Loss: 1.1659
Epoch [6/20], lter [161/1752] Loss: 2.1415
Epoch [6/20], lter [171/1752] Loss: 1.5642
Epoch [6/20], lter [181/1752] Loss: 1.4199
Epoch [6/20], lter [191/1752] Loss: 1.6446
Epoch [6/20], lter [201/1752] Loss: 1.5094
Epoch [6/20], lter [211/1752] Loss: 1.5238
Epoch [6/20], lter [221/1752] Loss

 30%|███       | 6/20 [2:45:31<6:25:54, 1653.87s/it]

Epoch:  5 | test loss : 1.5356
Epoch [7/20], lter [1/1752] Loss: 1.6904
Epoch [7/20], lter [11/1752] Loss: 1.7984
Epoch [7/20], lter [21/1752] Loss: 1.7034
Epoch [7/20], lter [31/1752] Loss: 1.6122
Epoch [7/20], lter [41/1752] Loss: 1.8017
Epoch [7/20], lter [51/1752] Loss: 1.5272
Epoch [7/20], lter [61/1752] Loss: 1.4532
Epoch [7/20], lter [71/1752] Loss: 1.8153
Epoch [7/20], lter [81/1752] Loss: 1.8948
Epoch [7/20], lter [91/1752] Loss: 1.8535
Epoch [7/20], lter [101/1752] Loss: 1.4612
Epoch [7/20], lter [111/1752] Loss: 1.4330
Epoch [7/20], lter [121/1752] Loss: 1.9840
Epoch [7/20], lter [131/1752] Loss: 1.6136
Epoch [7/20], lter [141/1752] Loss: 1.2734
Epoch [7/20], lter [151/1752] Loss: 1.5807
Epoch [7/20], lter [161/1752] Loss: 0.8953
Epoch [7/20], lter [171/1752] Loss: 1.3738
Epoch [7/20], lter [181/1752] Loss: 1.4657
Epoch [7/20], lter [191/1752] Loss: 1.9013
Epoch [7/20], lter [201/1752] Loss: 1.1525
Epoch [7/20], lter [211/1752] Loss: 1.6565
Epoch [7/20], lter [221/1752] Loss

 35%|███▌      | 7/20 [3:13:03<5:58:09, 1653.01s/it]

Epoch:  6 | test loss : 1.4883
Epoch [8/20], lter [1/1752] Loss: 1.4237
Epoch [8/20], lter [11/1752] Loss: 1.3853
Epoch [8/20], lter [21/1752] Loss: 1.2973
Epoch [8/20], lter [31/1752] Loss: 2.1422
Epoch [8/20], lter [41/1752] Loss: 1.2826
Epoch [8/20], lter [51/1752] Loss: 1.4732
Epoch [8/20], lter [61/1752] Loss: 1.3780
Epoch [8/20], lter [71/1752] Loss: 1.3074
Epoch [8/20], lter [81/1752] Loss: 1.9452
Epoch [8/20], lter [91/1752] Loss: 1.4990
Epoch [8/20], lter [101/1752] Loss: 1.3839
Epoch [8/20], lter [111/1752] Loss: 1.7800
Epoch [8/20], lter [121/1752] Loss: 1.3993
Epoch [8/20], lter [131/1752] Loss: 1.3655
Epoch [8/20], lter [141/1752] Loss: 1.7330
Epoch [8/20], lter [151/1752] Loss: 1.5287
Epoch [8/20], lter [161/1752] Loss: 1.6063
Epoch [8/20], lter [171/1752] Loss: 1.3165
Epoch [8/20], lter [181/1752] Loss: 2.3169
Epoch [8/20], lter [191/1752] Loss: 0.9322
Epoch [8/20], lter [201/1752] Loss: 1.8502
Epoch [8/20], lter [211/1752] Loss: 1.4844
Epoch [8/20], lter [221/1752] Loss

 40%|████      | 8/20 [3:40:22<5:29:45, 1648.80s/it]

Epoch:  7 | test loss : 1.4879
Epoch [9/20], lter [1/1752] Loss: 1.7687
Epoch [9/20], lter [11/1752] Loss: 1.5134
Epoch [9/20], lter [21/1752] Loss: 1.5503
Epoch [9/20], lter [31/1752] Loss: 1.0972
Epoch [9/20], lter [41/1752] Loss: 1.9204
Epoch [9/20], lter [51/1752] Loss: 1.3990
Epoch [9/20], lter [61/1752] Loss: 1.5018
Epoch [9/20], lter [71/1752] Loss: 1.1367
Epoch [9/20], lter [81/1752] Loss: 1.4943
Epoch [9/20], lter [91/1752] Loss: 1.4425
Epoch [9/20], lter [101/1752] Loss: 1.7879
Epoch [9/20], lter [111/1752] Loss: 1.1923
Epoch [9/20], lter [121/1752] Loss: 1.3105
Epoch [9/20], lter [131/1752] Loss: 1.4668
Epoch [9/20], lter [141/1752] Loss: 1.6279
Epoch [9/20], lter [151/1752] Loss: 1.1832
Epoch [9/20], lter [161/1752] Loss: 1.5628
Epoch [9/20], lter [171/1752] Loss: 1.6241
Epoch [9/20], lter [181/1752] Loss: 1.4455
Epoch [9/20], lter [191/1752] Loss: 1.4198
Epoch [9/20], lter [201/1752] Loss: 1.3823
Epoch [9/20], lter [211/1752] Loss: 1.6949
Epoch [9/20], lter [221/1752] Loss

 45%|████▌     | 9/20 [4:08:00<5:02:48, 1651.71s/it]

Epoch:  8 | test loss : 1.5138
Epoch [10/20], lter [1/1752] Loss: 1.3347
Epoch [10/20], lter [11/1752] Loss: 1.4857
Epoch [10/20], lter [21/1752] Loss: 1.5404
Epoch [10/20], lter [31/1752] Loss: 1.2331
Epoch [10/20], lter [41/1752] Loss: 1.4190
Epoch [10/20], lter [51/1752] Loss: 1.5837
Epoch [10/20], lter [61/1752] Loss: 1.8434
Epoch [10/20], lter [71/1752] Loss: 1.4238
Epoch [10/20], lter [81/1752] Loss: 1.3544
Epoch [10/20], lter [91/1752] Loss: 1.4034
Epoch [10/20], lter [101/1752] Loss: 1.6754
Epoch [10/20], lter [111/1752] Loss: 1.4465
Epoch [10/20], lter [121/1752] Loss: 1.2830
Epoch [10/20], lter [131/1752] Loss: 1.4334
Epoch [10/20], lter [141/1752] Loss: 2.0248
Epoch [10/20], lter [151/1752] Loss: 1.2998
Epoch [10/20], lter [161/1752] Loss: 1.4236
Epoch [10/20], lter [171/1752] Loss: 1.3069
Epoch [10/20], lter [181/1752] Loss: 1.2191
Epoch [10/20], lter [191/1752] Loss: 1.4140
Epoch [10/20], lter [201/1752] Loss: 1.6335
Epoch [10/20], lter [211/1752] Loss: 1.4353
Epoch [10/20

 50%|█████     | 10/20 [4:35:20<4:34:40, 1648.03s/it]

Epoch:  9 | test loss : 1.5075
Epoch [11/20], lter [1/1752] Loss: 1.2542
Epoch [11/20], lter [11/1752] Loss: 1.0517
Epoch [11/20], lter [21/1752] Loss: 1.5243
Epoch [11/20], lter [31/1752] Loss: 1.1899
Epoch [11/20], lter [41/1752] Loss: 1.3044
Epoch [11/20], lter [51/1752] Loss: 1.3496
Epoch [11/20], lter [61/1752] Loss: 1.4184
Epoch [11/20], lter [71/1752] Loss: 1.2503
Epoch [11/20], lter [81/1752] Loss: 1.2327
Epoch [11/20], lter [91/1752] Loss: 1.5011
Epoch [11/20], lter [101/1752] Loss: 1.5291
Epoch [11/20], lter [111/1752] Loss: 1.3353
Epoch [11/20], lter [121/1752] Loss: 1.3776
Epoch [11/20], lter [131/1752] Loss: 1.6273
Epoch [11/20], lter [141/1752] Loss: 1.2846
Epoch [11/20], lter [151/1752] Loss: 1.4432
Epoch [11/20], lter [161/1752] Loss: 1.5245
Epoch [11/20], lter [171/1752] Loss: 1.1881
Epoch [11/20], lter [181/1752] Loss: 1.5633
Epoch [11/20], lter [191/1752] Loss: 1.3115
Epoch [11/20], lter [201/1752] Loss: 1.9413
Epoch [11/20], lter [211/1752] Loss: 1.8908
Epoch [11/20

 55%|█████▌    | 11/20 [5:02:52<4:07:23, 1649.26s/it]

Epoch:  10 | test loss : 1.4433
Epoch [12/20], lter [1/1752] Loss: 1.5630
Epoch [12/20], lter [11/1752] Loss: 1.3627
Epoch [12/20], lter [21/1752] Loss: 1.3826
Epoch [12/20], lter [31/1752] Loss: 1.4422
Epoch [12/20], lter [41/1752] Loss: 1.1056
Epoch [12/20], lter [51/1752] Loss: 1.3289
Epoch [12/20], lter [61/1752] Loss: 1.3091
Epoch [12/20], lter [71/1752] Loss: 1.6513
Epoch [12/20], lter [81/1752] Loss: 1.5449
Epoch [12/20], lter [91/1752] Loss: 1.4414
Epoch [12/20], lter [101/1752] Loss: 1.4439
Epoch [12/20], lter [111/1752] Loss: 1.2602
Epoch [12/20], lter [121/1752] Loss: 1.4182
Epoch [12/20], lter [131/1752] Loss: 1.3743
Epoch [12/20], lter [141/1752] Loss: 1.2651
Epoch [12/20], lter [151/1752] Loss: 1.7398
Epoch [12/20], lter [161/1752] Loss: 1.3825
Epoch [12/20], lter [171/1752] Loss: 1.4940
Epoch [12/20], lter [181/1752] Loss: 1.2469
Epoch [12/20], lter [191/1752] Loss: 1.5586
Epoch [12/20], lter [201/1752] Loss: 1.1172
Epoch [12/20], lter [211/1752] Loss: 1.5851
Epoch [12/2

 60%|██████    | 12/20 [5:30:24<3:39:59, 1649.99s/it]

Epoch:  11 | test loss : 1.4337
Epoch [13/20], lter [1/1752] Loss: 1.3035
Epoch [13/20], lter [11/1752] Loss: 1.3322
Epoch [13/20], lter [21/1752] Loss: 1.6000
Epoch [13/20], lter [31/1752] Loss: 1.7164
Epoch [13/20], lter [41/1752] Loss: 1.4380
Epoch [13/20], lter [51/1752] Loss: 1.4902
Epoch [13/20], lter [61/1752] Loss: 1.3440
Epoch [13/20], lter [71/1752] Loss: 1.5557
Epoch [13/20], lter [81/1752] Loss: 0.9779
Epoch [13/20], lter [91/1752] Loss: 1.0201
Epoch [13/20], lter [101/1752] Loss: 1.8935
Epoch [13/20], lter [111/1752] Loss: 1.1868
Epoch [13/20], lter [121/1752] Loss: 1.4374
Epoch [13/20], lter [131/1752] Loss: 1.6323
Epoch [13/20], lter [141/1752] Loss: 1.6182
Epoch [13/20], lter [151/1752] Loss: 1.6590
Epoch [13/20], lter [161/1752] Loss: 1.5845
Epoch [13/20], lter [171/1752] Loss: 1.3138
Epoch [13/20], lter [181/1752] Loss: 1.0034
Epoch [13/20], lter [191/1752] Loss: 1.6694
Epoch [13/20], lter [201/1752] Loss: 1.2328
Epoch [13/20], lter [211/1752] Loss: 1.9808
Epoch [13/2

 65%|██████▌   | 13/20 [5:57:54<3:12:31, 1650.15s/it]

Epoch:  12 | test loss : 1.4373
Epoch [14/20], lter [1/1752] Loss: 1.4823
Epoch [14/20], lter [11/1752] Loss: 1.3764
Epoch [14/20], lter [21/1752] Loss: 1.9709
Epoch [14/20], lter [31/1752] Loss: 0.9239
Epoch [14/20], lter [41/1752] Loss: 1.2332
Epoch [14/20], lter [51/1752] Loss: 1.7556
Epoch [14/20], lter [61/1752] Loss: 1.7975
Epoch [14/20], lter [71/1752] Loss: 1.4114
Epoch [14/20], lter [81/1752] Loss: 1.4372
Epoch [14/20], lter [91/1752] Loss: 1.3363
Epoch [14/20], lter [101/1752] Loss: 1.2002
Epoch [14/20], lter [111/1752] Loss: 1.5174
Epoch [14/20], lter [121/1752] Loss: 1.3606
Epoch [14/20], lter [131/1752] Loss: 1.2847
Epoch [14/20], lter [141/1752] Loss: 1.5000
Epoch [14/20], lter [151/1752] Loss: 2.0104
Epoch [14/20], lter [161/1752] Loss: 1.1106
Epoch [14/20], lter [171/1752] Loss: 0.8933
Epoch [14/20], lter [181/1752] Loss: 1.5567
Epoch [14/20], lter [191/1752] Loss: 1.4632
Epoch [14/20], lter [201/1752] Loss: 1.6100
Epoch [14/20], lter [211/1752] Loss: 1.4797
Epoch [14/2

 70%|███████   | 14/20 [6:25:32<2:45:14, 1652.45s/it]

Epoch:  13 | test loss : 1.4417
Epoch [15/20], lter [1/1752] Loss: 1.2642
Epoch [15/20], lter [11/1752] Loss: 1.4911
Epoch [15/20], lter [21/1752] Loss: 1.4647
Epoch [15/20], lter [31/1752] Loss: 1.7806
Epoch [15/20], lter [41/1752] Loss: 1.7370
Epoch [15/20], lter [51/1752] Loss: 1.4871
Epoch [15/20], lter [61/1752] Loss: 1.7134
Epoch [15/20], lter [71/1752] Loss: 1.6625
Epoch [15/20], lter [81/1752] Loss: 1.4230
Epoch [15/20], lter [91/1752] Loss: 1.4043
Epoch [15/20], lter [101/1752] Loss: 1.4650
Epoch [15/20], lter [111/1752] Loss: 1.3678
Epoch [15/20], lter [121/1752] Loss: 1.6965
Epoch [15/20], lter [131/1752] Loss: 1.6359
Epoch [15/20], lter [141/1752] Loss: 1.1467
Epoch [15/20], lter [151/1752] Loss: 1.2355
Epoch [15/20], lter [161/1752] Loss: 1.4697
Epoch [15/20], lter [171/1752] Loss: 1.0879
Epoch [15/20], lter [181/1752] Loss: 1.2018
Epoch [15/20], lter [191/1752] Loss: 1.3424
Epoch [15/20], lter [201/1752] Loss: 1.4852
Epoch [15/20], lter [211/1752] Loss: 1.2341
Epoch [15/2

 75%|███████▌  | 15/20 [6:53:13<2:17:54, 1654.99s/it]

Epoch:  14 | test loss : 1.4381
Epoch [16/20], lter [1/1752] Loss: 1.5716
Epoch [16/20], lter [11/1752] Loss: 1.0122
Epoch [16/20], lter [21/1752] Loss: 0.9900
Epoch [16/20], lter [31/1752] Loss: 1.4060
Epoch [16/20], lter [41/1752] Loss: 1.3690
Epoch [16/20], lter [51/1752] Loss: 1.0861
Epoch [16/20], lter [61/1752] Loss: 1.3070
Epoch [16/20], lter [71/1752] Loss: 1.6686
Epoch [16/20], lter [81/1752] Loss: 1.4530
Epoch [16/20], lter [91/1752] Loss: 1.5644
Epoch [16/20], lter [101/1752] Loss: 1.6991
Epoch [16/20], lter [111/1752] Loss: 1.4487
Epoch [16/20], lter [121/1752] Loss: 1.2607
Epoch [16/20], lter [131/1752] Loss: 1.4099
Epoch [16/20], lter [141/1752] Loss: 1.3986
Epoch [16/20], lter [151/1752] Loss: 1.4560
Epoch [16/20], lter [161/1752] Loss: 1.1736
Epoch [16/20], lter [171/1752] Loss: 1.3361
Epoch [16/20], lter [181/1752] Loss: 1.5000
Epoch [16/20], lter [191/1752] Loss: 1.1553
Epoch [16/20], lter [201/1752] Loss: 1.0837
Epoch [16/20], lter [211/1752] Loss: 1.3594
Epoch [16/2

 80%|████████  | 16/20 [7:20:50<1:50:22, 1655.60s/it]

Epoch:  15 | test loss : 1.4363
Epoch [17/20], lter [1/1752] Loss: 1.4423
Epoch [17/20], lter [11/1752] Loss: 1.4875
Epoch [17/20], lter [21/1752] Loss: 1.2518
Epoch [17/20], lter [31/1752] Loss: 1.3471
Epoch [17/20], lter [41/1752] Loss: 1.4863
Epoch [17/20], lter [51/1752] Loss: 1.3663
Epoch [17/20], lter [61/1752] Loss: 1.4036
Epoch [17/20], lter [71/1752] Loss: 1.4893
Epoch [17/20], lter [81/1752] Loss: 1.5603
Epoch [17/20], lter [91/1752] Loss: 1.3671
Epoch [17/20], lter [101/1752] Loss: 1.5632
Epoch [17/20], lter [111/1752] Loss: 1.3500
Epoch [17/20], lter [121/1752] Loss: 1.8198
Epoch [17/20], lter [131/1752] Loss: 1.3957
Epoch [17/20], lter [141/1752] Loss: 1.3220
Epoch [17/20], lter [151/1752] Loss: 1.3639
Epoch [17/20], lter [161/1752] Loss: 1.4554
Epoch [17/20], lter [171/1752] Loss: 1.3634
Epoch [17/20], lter [181/1752] Loss: 1.4628
Epoch [17/20], lter [191/1752] Loss: 1.6201
Epoch [17/20], lter [201/1752] Loss: 1.3941
Epoch [17/20], lter [211/1752] Loss: 1.4886
Epoch [17/2

 85%|████████▌ | 17/20 [7:48:15<1:22:37, 1652.50s/it]

Epoch:  16 | test loss : 1.4466
Epoch [18/20], lter [1/1752] Loss: 1.5268
Epoch [18/20], lter [11/1752] Loss: 1.4538
Epoch [18/20], lter [21/1752] Loss: 1.2379
Epoch [18/20], lter [31/1752] Loss: 1.3260
Epoch [18/20], lter [41/1752] Loss: 1.7801
Epoch [18/20], lter [51/1752] Loss: 1.4099
Epoch [18/20], lter [61/1752] Loss: 1.7270
Epoch [18/20], lter [71/1752] Loss: 1.3519
Epoch [18/20], lter [81/1752] Loss: 1.3998
Epoch [18/20], lter [91/1752] Loss: 1.3316
Epoch [18/20], lter [101/1752] Loss: 1.3153
Epoch [18/20], lter [111/1752] Loss: 1.3222
Epoch [18/20], lter [121/1752] Loss: 1.6570
Epoch [18/20], lter [131/1752] Loss: 1.2887
Epoch [18/20], lter [141/1752] Loss: 1.1761
Epoch [18/20], lter [151/1752] Loss: 1.6436
Epoch [18/20], lter [161/1752] Loss: 1.4852
Epoch [18/20], lter [171/1752] Loss: 1.4826
Epoch [18/20], lter [181/1752] Loss: 1.4549
Epoch [18/20], lter [191/1752] Loss: 1.5461
Epoch [18/20], lter [201/1752] Loss: 1.3878
Epoch [18/20], lter [211/1752] Loss: 1.2691
Epoch [18/2

 90%|█████████ | 18/20 [8:15:46<55:03, 1651.79s/it]  

Epoch:  17 | test loss : 1.4314
Epoch [19/20], lter [1/1752] Loss: 1.4007
Epoch [19/20], lter [11/1752] Loss: 1.3602
Epoch [19/20], lter [21/1752] Loss: 1.2583
Epoch [19/20], lter [31/1752] Loss: 1.2139
Epoch [19/20], lter [41/1752] Loss: 1.5530
Epoch [19/20], lter [51/1752] Loss: 1.4511
Epoch [19/20], lter [61/1752] Loss: 1.5383
Epoch [19/20], lter [71/1752] Loss: 1.3878
Epoch [19/20], lter [81/1752] Loss: 1.5036
Epoch [19/20], lter [91/1752] Loss: 1.6198
Epoch [19/20], lter [101/1752] Loss: 1.5806
Epoch [19/20], lter [111/1752] Loss: 0.9200
Epoch [19/20], lter [121/1752] Loss: 1.3399
Epoch [19/20], lter [131/1752] Loss: 1.6258
Epoch [19/20], lter [141/1752] Loss: 1.3706
Epoch [19/20], lter [151/1752] Loss: 1.7625
Epoch [19/20], lter [161/1752] Loss: 1.7452
Epoch [19/20], lter [171/1752] Loss: 1.3867
Epoch [19/20], lter [181/1752] Loss: 1.3914
Epoch [19/20], lter [191/1752] Loss: 1.5240
Epoch [19/20], lter [201/1752] Loss: 1.6064
Epoch [19/20], lter [211/1752] Loss: 1.5161
Epoch [19/2

 95%|█████████▌| 19/20 [8:43:17<27:31, 1651.61s/it]

Epoch:  18 | test loss : 1.4195
Epoch [20/20], lter [1/1752] Loss: 1.5243
Epoch [20/20], lter [11/1752] Loss: 1.3079
Epoch [20/20], lter [21/1752] Loss: 1.5430
Epoch [20/20], lter [31/1752] Loss: 1.0697
Epoch [20/20], lter [41/1752] Loss: 0.9020
Epoch [20/20], lter [51/1752] Loss: 1.8747
Epoch [20/20], lter [61/1752] Loss: 1.5185
Epoch [20/20], lter [71/1752] Loss: 1.0407
Epoch [20/20], lter [81/1752] Loss: 1.4169
Epoch [20/20], lter [91/1752] Loss: 1.4658
Epoch [20/20], lter [101/1752] Loss: 1.3180
Epoch [20/20], lter [111/1752] Loss: 1.1695
Epoch [20/20], lter [121/1752] Loss: 0.8459
Epoch [20/20], lter [131/1752] Loss: 1.4879
Epoch [20/20], lter [141/1752] Loss: 1.4309
Epoch [20/20], lter [151/1752] Loss: 2.0136
Epoch [20/20], lter [161/1752] Loss: 1.4160
Epoch [20/20], lter [171/1752] Loss: 1.1937
Epoch [20/20], lter [181/1752] Loss: 1.1949
Epoch [20/20], lter [191/1752] Loss: 1.5975
Epoch [20/20], lter [201/1752] Loss: 1.9980
Epoch [20/20], lter [211/1752] Loss: 1.1770
Epoch [20/2

100%|██████████| 20/20 [9:10:54<00:00, 1652.74s/it]

Epoch:  19 | test loss : 1.4339





In [9]:
correct_predictions, total_predictions = 0, 0

# Loop over the batches of the test set
for batch_data in dataloaders['train']:
    # Forward pass
    inputs, target = batch_data
    output = model(inputs)

    # Get the predicted class
    _, predicted = torch.max(output.data, 1)

    # Increase the total predictions counter
    total_predictions += target.size(0)

    # Check if the predicted class is equal to the true class
    correct_predictions += (predicted == target).sum().item()

# Calculate the accuracy
accuracy = correct_predictions / total_predictions
print(f'Accuracy: {accuracy * 100:.2f}%')

Accuracy: 47.68%


In [10]:
correct_predictions, total_predictions = 0, 0

# Loop over the batches of the test set
for batch_data in dataloaders['test']:
    # Forward pass
    inputs, target = batch_data
    output = model(inputs)

    # Get the predicted class
    _, predicted = torch.max(output.data, 1)

    # Increase the total predictions counter
    total_predictions += target.size(0)

    # Check if the predicted class is equal to the true class
    correct_predictions += (predicted == target).sum().item()

# Calculate the accuracy
accuracy = correct_predictions / total_predictions
print(f'Accuracy: {accuracy * 100:.2f}%')

Accuracy: 44.11%


In [11]:
# Save the model
torch.save(model.state_dict(), 'ResNet50.pth')