In [30]:
# Change the working directory to the parent folder (presumably the repository
# root, so the `src` package imported below can be resolved — confirm).
# NOTE(review): not idempotent — re-running this cell climbs one more level.
%cd ..

/home/ltorres/leo/tesis/cloud-classification


In [31]:
# Enable IPython's autoreload extension in mode 2 so edited modules are
# re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [32]:
from torchvision.io import read_image
import glob
import os
import math


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from random import sample
import torch
import torch.nn as nn
from torch.optim import lr_scheduler

In [33]:
from src.dataset import GCD
from src import config
from src import engine_gnn as engine
from src import utils

#from src.models.graph_nets import GraphConvGNN, GATConvGNN
from src.models.old.initial_graphnets import GATConvGNN

In [34]:
from sklearn.metrics import accuracy_score

In [35]:
import wandb

In [36]:
# Authenticate with Weights & Biases before a run is started further down.
wandb.login()

True

In [37]:
# Select the compute device. The notebook was pinned to the second GPU
# ('cuda:1'); keep that preference, but fall back to the first GPU and then to
# the CPU so the notebook still runs on hosts with fewer devices.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [38]:
# Training hyperparameters: number of passes over the training set and the
# learning rate handed to the optimizer.
EPOCHS = 100
LR = 0.0002

### Data loaders

In [39]:
# Paths of the images in the GCD 'train' split, resolved under config.DATA_DIR
# (presumably a list of file paths — confirm against utils.get_gcd_paths).
path_train_images = utils.get_gcd_paths(config.DATA_DIR,'train')

In [41]:
# Sanity check: number of training image paths found.
len(path_train_images)

10000

In [42]:
# Training dataset over the train image paths, created with resize=256 and the
# 'aug' augmentation setting.
train_dataset = GCD(path_train_images, resize=256, aug_types='aug')

In [11]:
# `train_dataset` is built in the previous cell; this cell only wraps it in a
# DataLoader. Shuffling is enabled so every epoch sees the samples in a new order.
train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        num_workers=4,
        shuffle=True,
    )

In [12]:
# Evaluation data: the GCD 'test' split, resized to 256 with no augmentation
# argument, served in a fixed (unshuffled) order.
path_test_images = utils.get_gcd_paths(config.DATA_DIR, 'test')
test_dataset = GCD(path_test_images, resize=256)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    shuffle=False,
    num_workers=4,
    batch_size=config.BATCH_SIZE,
)

### Weights & Biases run configuration

In [13]:
# Experiment identifier: used as the W&B run name and as the checkpoint file stem.
# NOTE(review): the name contains "LOGE", but the criterion cell instantiates
# nn.CrossEntropyLoss (utils.loge_loss is commented out) — confirm the label.
exp_name = '6_GATConvGNN_LOGE_SGD_01_04_22'

In [14]:
# Start the Weights & Biases run and record the hyperparameters of this experiment.
wandb.init(
    # Set the project where this run will be logged
    project="cloud classification",
    # We pass a run name (otherwise it'll be randomly assigned, like sunshine-lollypop-10)
    name=exp_name,
    # Track hyperparameters and run metadata
    config={
        "learning_rate": LR,
        "architecture": "GATConvGNN",
        # NOTE(review): the criterion cell instantiates nn.CrossEntropyLoss
        # (utils.loge_loss is commented out) — confirm this label is intended.
        "loss": "LogeLoss",
        "optim": "SGD",
        "dataset": "GCD",
        # Log the epoch count the training loop actually iterates over (the
        # notebook-level EPOCHS), not config.EPOCHS, which may differ.
        "epochs": EPOCHS,
    },
)

### Model

In [15]:
from torchvision import models

In [16]:
# Release cached, unused GPU memory held by PyTorch's allocator before building the model.
torch.cuda.empty_cache()

In [None]:
# Alternative construction through utils.build_model_gatconv, kept commented out
# for reference (the meaning of the positional arguments other than the class
# count and device is not documented here):
# model  = utils.build_model_gatconv(
#                                     7, #GCD num classes
#                                     512,
#                                     3,
#                                     4,
#                                     0.75,
#                                     device
#                                    )
# Instantiate the GAT-based network with 7 outputs (the GCD class count, per the
# note above) and move it to the selected device.
model = GATConvGNN(7).to(device)

NameError: name 'GATConvGNN' is not defined

In [18]:
# Optimisation set-up: SGD with momentum 0.9 at learning rate LR, a step
# schedule that multiplies the learning rate by 0.1 every 10 scheduler steps,
# and cross-entropy as the training objective.
#optimizer  = torch.optim.Adam(model.parameters(), lr=3e-4)
optimizer = torch.optim.SGD(params=model.parameters(), lr=LR, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.1)

#criterion = utils.loge_loss
criterion = nn.CrossEntropyLoss()

In [19]:
# Checkpoint location relative to config.SAVE_PATH. The leading '/' is needed
# because the training cell builds the full path by plain string concatenation
# (config.SAVE_PATH + model_filename).
model_filename = f'/gcd_gnn/{exp_name}.pt'

In [20]:
# Map each integer label (0-6) to its class name, in label order.
class_mapper = dict(enumerate([
    '1_cumulus',
    '2_altocumulus',
    '3_cirrus',
    '4_clearsky',
    '5_stratocumulus',
    '6_cumulonimbus',
    '7_mixed',
]))

In [21]:
# Class names ordered by label index; passed to the W&B confusion-matrix plot.
classes = [class_mapper[label] for label in class_mapper]

In [22]:
# Display the ordered class names.
classes

['1_cumulus',
 '2_altocumulus',
 '3_cirrus',
 '4_clearsky',
 '5_stratocumulus',
 '6_cumulonimbus',
 '7_mixed']

### Train

In [None]:
# Best-checkpoint bookkeeping.
# NOTE(review): the best checkpoint is selected on the *test* split, so the
# final test metrics are optimistically biased; a held-out validation split
# would give an unbiased estimate.
best_accuracy = 0
best_iteration = 0
best_params = {}

# Build the checkpoint path once and make sure its directory exists, so the
# first torch.save does not fail on a fresh checkout.
checkpoint_path = config.SAVE_PATH + model_filename
checkpoint_dir = os.path.dirname(checkpoint_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

for e in range(EPOCHS):
    ### TRAIN DATASET
    preds, targets, loss = engine.train_fn(model, train_loader, criterion, optimizer, device=device)
    train_acc = accuracy_score(targets, preds)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    ### TEST DATASET
    test_preds, test_targets, test_loss = engine.eval_fn(model, test_loader, criterion, device=device)
    test_acc = accuracy_score(test_targets, test_preds)

    if test_acc > best_accuracy:
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Saved best parameters at epoch {e+1}")
        best_accuracy = test_acc
        best_iteration = e+1
        # state_dict() hands back references to the live parameter tensors, so
        # they must be copied; otherwise `best_params` silently follows later
        # epochs instead of staying a snapshot of the best one.
        best_params = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }

    print("EPOCH {}: Train acc: {:.2%} Train Loss: {:.4f} Test acc: {:.2%} Test Loss: {:.4f}".format(
        e+1,
        train_acc,
        loss,
        test_acc,
        test_loss
    ))

    metrics = {
                "train/train_loss": loss,
                "train/train_accuracy": train_acc,
                "test/test_loss": test_loss,
                "test/test_accuracy": test_acc,
              }

    wandb.log(metrics)

# Load the best model parameters back from disk; map_location keeps the load
# working even if the checkpoint was written from a different device.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Re-evaluate the restored checkpoint and store its metrics in the run summary.
test_preds, test_targets, test_loss = engine.eval_fn(model, test_loader, criterion, device=device)
wandb.summary['test_accuracy'] = accuracy_score(test_targets, test_preds)
wandb.summary['test_loss'] = test_loss

wandb.log({"conf_mat" : wandb.plot.confusion_matrix(probs=None,
                            preds=test_preds, y_true=test_targets,
                            class_names=classes)})

wandb.finish()

100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:18<00:00, 14.92it/s]


Saved best parameters at epoch 1
EPOCH 1: Train acc: 70.28% Train Loss: 1.7579 Test acc: 73.42% Test Loss: 1.6797


100%|██████████| 313/313 [00:57<00:00,  5.47it/s]
100%|██████████| 282/282 [00:18<00:00, 15.44it/s]


Saved best parameters at epoch 2
EPOCH 2: Train acc: 84.87% Train Loss: 0.8457 Test acc: 76.84% Test Loss: 1.6523


100%|██████████| 313/313 [00:57<00:00,  5.45it/s]
100%|██████████| 282/282 [00:18<00:00, 15.18it/s]


EPOCH 3: Train acc: 87.62% Train Loss: 0.6878 Test acc: 76.12% Test Loss: 1.7435


100%|██████████| 313/313 [00:57<00:00,  5.48it/s]
100%|██████████| 282/282 [00:18<00:00, 15.45it/s]


Saved best parameters at epoch 4
EPOCH 4: Train acc: 88.37% Train Loss: 0.6280 Test acc: 77.48% Test Loss: 1.5162


100%|██████████| 313/313 [00:57<00:00,  5.47it/s]
100%|██████████| 282/282 [00:18<00:00, 15.09it/s]


Saved best parameters at epoch 5
EPOCH 5: Train acc: 89.30% Train Loss: 0.5652 Test acc: 78.47% Test Loss: 1.5051


100%|██████████| 313/313 [00:57<00:00,  5.48it/s]
100%|██████████| 282/282 [00:18<00:00, 15.27it/s]


Saved best parameters at epoch 6
EPOCH 6: Train acc: 90.25% Train Loss: 0.5295 Test acc: 78.54% Test Loss: 1.4961


100%|██████████| 313/313 [00:57<00:00,  5.46it/s]
100%|██████████| 282/282 [00:18<00:00, 15.21it/s]


EPOCH 7: Train acc: 91.02% Train Loss: 0.4970 Test acc: 78.42% Test Loss: 1.5368


100%|██████████| 313/313 [00:57<00:00,  5.47it/s]
100%|██████████| 282/282 [00:18<00:00, 15.17it/s]


EPOCH 8: Train acc: 91.41% Train Loss: 0.4600 Test acc: 77.56% Test Loss: 1.6749


100%|██████████| 313/313 [00:57<00:00,  5.47it/s]
100%|██████████| 282/282 [00:18<00:00, 15.32it/s]


EPOCH 9: Train acc: 91.75% Train Loss: 0.4319 Test acc: 77.80% Test Loss: 1.5963


100%|██████████| 313/313 [00:57<00:00,  5.46it/s]
100%|██████████| 282/282 [00:19<00:00, 14.79it/s]


EPOCH 10: Train acc: 92.45% Train Loss: 0.4083 Test acc: 78.49% Test Loss: 1.5561


100%|██████████| 313/313 [00:57<00:00,  5.47it/s]
100%|██████████| 282/282 [00:20<00:00, 14.08it/s]


EPOCH 11: Train acc: 93.27% Train Loss: 0.3661 Test acc: 78.32% Test Loss: 1.6791


100%|██████████| 313/313 [00:57<00:00,  5.47it/s]
100%|██████████| 282/282 [00:18<00:00, 15.10it/s]


Saved best parameters at epoch 12
EPOCH 12: Train acc: 94.10% Train Loss: 0.3324 Test acc: 79.04% Test Loss: 1.5861


100%|██████████| 313/313 [00:57<00:00,  5.48it/s]
100%|██████████| 282/282 [00:18<00:00, 15.49it/s]


EPOCH 13: Train acc: 93.92% Train Loss: 0.3409 Test acc: 78.36% Test Loss: 1.7284


100%|██████████| 313/313 [00:57<00:00,  5.48it/s]
100%|██████████| 282/282 [00:19<00:00, 14.82it/s]


EPOCH 14: Train acc: 93.91% Train Loss: 0.3399 Test acc: 78.52% Test Loss: 1.6561


100%|██████████| 313/313 [00:57<00:00,  5.47it/s]
100%|██████████| 282/282 [00:18<00:00, 15.25it/s]


EPOCH 15: Train acc: 94.13% Train Loss: 0.3222 Test acc: 78.71% Test Loss: 1.6918


100%|██████████| 313/313 [00:57<00:00,  5.47it/s]
100%|██████████| 282/282 [00:18<00:00, 15.26it/s]


Saved best parameters at epoch 16
EPOCH 16: Train acc: 94.15% Train Loss: 0.3299 Test acc: 79.30% Test Loss: 1.5880


100%|██████████| 313/313 [00:57<00:00,  5.47it/s]
100%|██████████| 282/282 [00:18<00:00, 15.51it/s]


EPOCH 17: Train acc: 93.78% Train Loss: 0.3398 Test acc: 77.64% Test Loss: 1.7740


100%|██████████| 313/313 [00:57<00:00,  5.47it/s]
100%|██████████| 282/282 [00:19<00:00, 14.81it/s]


EPOCH 18: Train acc: 94.00% Train Loss: 0.3274 Test acc: 78.96% Test Loss: 1.6144


100%|██████████| 313/313 [00:57<00:00,  5.48it/s]
100%|██████████| 282/282 [00:18<00:00, 15.56it/s]


EPOCH 19: Train acc: 93.91% Train Loss: 0.3357 Test acc: 79.07% Test Loss: 1.5912


100%|██████████| 313/313 [00:57<00:00,  5.47it/s]
100%|██████████| 282/282 [00:18<00:00, 15.10it/s]


EPOCH 20: Train acc: 94.45% Train Loss: 0.3238 Test acc: 78.46% Test Loss: 1.7225


100%|██████████| 313/313 [00:57<00:00,  5.46it/s]
100%|██████████| 282/282 [00:18<00:00, 15.56it/s]


EPOCH 21: Train acc: 94.43% Train Loss: 0.3034 Test acc: 79.20% Test Loss: 1.6110


100%|██████████| 313/313 [00:57<00:00,  5.47it/s]
100%|██████████| 282/282 [00:18<00:00, 15.13it/s]


EPOCH 22: Train acc: 94.51% Train Loss: 0.3131 Test acc: 78.90% Test Loss: 1.6333


100%|██████████| 313/313 [00:57<00:00,  5.46it/s]
100%|██████████| 282/282 [00:17<00:00, 15.70it/s]


EPOCH 23: Train acc: 94.46% Train Loss: 0.2977 Test acc: 78.91% Test Loss: 1.6390


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:17<00:00, 15.69it/s]


Saved best parameters at epoch 24
EPOCH 24: Train acc: 94.67% Train Loss: 0.2988 Test acc: 79.37% Test Loss: 1.5921


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:18<00:00, 15.66it/s]


EPOCH 25: Train acc: 94.39% Train Loss: 0.3050 Test acc: 78.77% Test Loss: 1.6519


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:17<00:00, 15.74it/s]


EPOCH 26: Train acc: 94.35% Train Loss: 0.3142 Test acc: 78.57% Test Loss: 1.7362


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:18<00:00, 15.63it/s]


EPOCH 27: Train acc: 94.33% Train Loss: 0.3142 Test acc: 79.22% Test Loss: 1.5952


100%|██████████| 313/313 [00:57<00:00,  5.45it/s]
100%|██████████| 282/282 [00:18<00:00, 15.07it/s]


EPOCH 28: Train acc: 94.63% Train Loss: 0.2960 Test acc: 78.80% Test Loss: 1.6331


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:18<00:00, 15.59it/s]


EPOCH 29: Train acc: 94.50% Train Loss: 0.3049 Test acc: 78.62% Test Loss: 1.6586


100%|██████████| 313/313 [00:57<00:00,  5.45it/s]
100%|██████████| 282/282 [00:17<00:00, 15.76it/s]


EPOCH 30: Train acc: 94.69% Train Loss: 0.2954 Test acc: 78.67% Test Loss: 1.6509


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:18<00:00, 15.66it/s]


EPOCH 31: Train acc: 94.61% Train Loss: 0.2922 Test acc: 78.54% Test Loss: 1.6851


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:17<00:00, 15.72it/s]


EPOCH 32: Train acc: 94.46% Train Loss: 0.3055 Test acc: 78.90% Test Loss: 1.6338


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:17<00:00, 15.80it/s]


EPOCH 33: Train acc: 94.25% Train Loss: 0.3135 Test acc: 78.37% Test Loss: 1.7161


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:18<00:00, 15.64it/s]


EPOCH 34: Train acc: 94.43% Train Loss: 0.3028 Test acc: 78.96% Test Loss: 1.6148


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:17<00:00, 15.79it/s]


EPOCH 35: Train acc: 94.47% Train Loss: 0.3054 Test acc: 78.21% Test Loss: 1.7832


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:18<00:00, 15.66it/s]


EPOCH 36: Train acc: 94.49% Train Loss: 0.3038 Test acc: 79.00% Test Loss: 1.5995


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:17<00:00, 15.68it/s]


EPOCH 37: Train acc: 94.36% Train Loss: 0.3018 Test acc: 78.27% Test Loss: 1.7147


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:18<00:00, 15.63it/s]


EPOCH 38: Train acc: 94.58% Train Loss: 0.3040 Test acc: 78.57% Test Loss: 1.6760


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:18<00:00, 15.61it/s]


EPOCH 39: Train acc: 94.68% Train Loss: 0.3007 Test acc: 78.77% Test Loss: 1.6307


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:17<00:00, 15.69it/s]


EPOCH 40: Train acc: 94.33% Train Loss: 0.3205 Test acc: 79.08% Test Loss: 1.5721


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:17<00:00, 15.79it/s]


EPOCH 41: Train acc: 94.66% Train Loss: 0.3067 Test acc: 78.74% Test Loss: 1.6588


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:17<00:00, 15.70it/s]


EPOCH 42: Train acc: 94.65% Train Loss: 0.3038 Test acc: 78.76% Test Loss: 1.6301


100%|██████████| 313/313 [00:57<00:00,  5.45it/s]
100%|██████████| 282/282 [00:17<00:00, 15.74it/s]


EPOCH 43: Train acc: 94.38% Train Loss: 0.3122 Test acc: 78.37% Test Loss: 1.7089


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:17<00:00, 15.69it/s]


EPOCH 44: Train acc: 94.74% Train Loss: 0.3041 Test acc: 79.33% Test Loss: 1.5600


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:18<00:00, 15.63it/s]


EPOCH 45: Train acc: 94.32% Train Loss: 0.3065 Test acc: 79.17% Test Loss: 1.6057


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:18<00:00, 15.64it/s]


EPOCH 46: Train acc: 94.79% Train Loss: 0.2950 Test acc: 78.34% Test Loss: 1.6863


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:17<00:00, 15.73it/s]


EPOCH 47: Train acc: 94.32% Train Loss: 0.3134 Test acc: 78.89% Test Loss: 1.6263


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:18<00:00, 15.57it/s]


EPOCH 48: Train acc: 94.47% Train Loss: 0.2991 Test acc: 78.91% Test Loss: 1.6795


100%|██████████| 313/313 [00:57<00:00,  5.45it/s]
100%|██████████| 282/282 [00:17<00:00, 15.69it/s]


EPOCH 49: Train acc: 94.28% Train Loss: 0.3083 Test acc: 78.63% Test Loss: 1.6693


100%|██████████| 313/313 [00:57<00:00,  5.46it/s]
100%|██████████| 282/282 [00:17<00:00, 15.76it/s]


EPOCH 50: Train acc: 94.35% Train Loss: 0.3107 Test acc: 78.79% Test Loss: 1.6117


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:17<00:00, 15.74it/s]


EPOCH 51: Train acc: 94.42% Train Loss: 0.3121 Test acc: 78.64% Test Loss: 1.6273


100%|██████████| 313/313 [00:57<00:00,  5.45it/s]
100%|██████████| 282/282 [00:17<00:00, 15.70it/s]


EPOCH 52: Train acc: 94.31% Train Loss: 0.3058 Test acc: 78.63% Test Loss: 1.6860


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:17<00:00, 15.82it/s]


Saved best parameters at epoch 53
EPOCH 53: Train acc: 94.43% Train Loss: 0.3023 Test acc: 79.77% Test Loss: 1.5409


100%|██████████| 313/313 [00:57<00:00,  5.45it/s]
100%|██████████| 282/282 [00:17<00:00, 15.74it/s]


EPOCH 54: Train acc: 94.65% Train Loss: 0.3089 Test acc: 79.02% Test Loss: 1.6321


100%|██████████| 313/313 [00:57<00:00,  5.45it/s]
100%|██████████| 282/282 [00:17<00:00, 15.76it/s]


EPOCH 55: Train acc: 94.44% Train Loss: 0.3018 Test acc: 78.79% Test Loss: 1.6310


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:17<00:00, 15.68it/s]


EPOCH 56: Train acc: 94.48% Train Loss: 0.3108 Test acc: 79.09% Test Loss: 1.6203


100%|██████████| 313/313 [00:57<00:00,  5.45it/s]
100%|██████████| 282/282 [00:18<00:00, 15.60it/s]


EPOCH 57: Train acc: 94.52% Train Loss: 0.3010 Test acc: 79.07% Test Loss: 1.6283


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:18<00:00, 15.64it/s]


EPOCH 58: Train acc: 94.42% Train Loss: 0.3016 Test acc: 78.99% Test Loss: 1.5922


100%|██████████| 313/313 [00:57<00:00,  5.45it/s]
100%|██████████| 282/282 [00:18<00:00, 15.61it/s]


EPOCH 59: Train acc: 94.53% Train Loss: 0.3047 Test acc: 78.67% Test Loss: 1.6605


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:17<00:00, 15.81it/s]


EPOCH 60: Train acc: 94.64% Train Loss: 0.2981 Test acc: 78.43% Test Loss: 1.7191


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:17<00:00, 15.74it/s]


EPOCH 61: Train acc: 94.47% Train Loss: 0.3045 Test acc: 78.62% Test Loss: 1.6607


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:17<00:00, 15.71it/s]


EPOCH 62: Train acc: 94.61% Train Loss: 0.3009 Test acc: 78.93% Test Loss: 1.5672


100%|██████████| 313/313 [00:57<00:00,  5.45it/s]
100%|██████████| 282/282 [00:17<00:00, 15.71it/s]


EPOCH 63: Train acc: 94.54% Train Loss: 0.3047 Test acc: 79.13% Test Loss: 1.6193


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:18<00:00, 15.66it/s]


EPOCH 64: Train acc: 94.42% Train Loss: 0.3023 Test acc: 78.38% Test Loss: 1.6780


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:17<00:00, 15.85it/s]


EPOCH 65: Train acc: 94.34% Train Loss: 0.3096 Test acc: 79.12% Test Loss: 1.5856


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:17<00:00, 15.69it/s]


EPOCH 66: Train acc: 94.47% Train Loss: 0.3010 Test acc: 78.90% Test Loss: 1.6244


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:18<00:00, 15.65it/s]


EPOCH 67: Train acc: 94.50% Train Loss: 0.3159 Test acc: 78.69% Test Loss: 1.6377


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:18<00:00, 15.64it/s]


EPOCH 68: Train acc: 94.44% Train Loss: 0.3075 Test acc: 78.74% Test Loss: 1.6526


100%|██████████| 313/313 [00:57<00:00,  5.45it/s]
100%|██████████| 282/282 [00:17<00:00, 15.78it/s]


EPOCH 69: Train acc: 94.40% Train Loss: 0.3009 Test acc: 78.53% Test Loss: 1.6811


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:17<00:00, 15.74it/s]


EPOCH 70: Train acc: 94.67% Train Loss: 0.3040 Test acc: 78.98% Test Loss: 1.5922


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:18<00:00, 15.56it/s]


EPOCH 71: Train acc: 94.44% Train Loss: 0.3055 Test acc: 79.21% Test Loss: 1.5704


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:17<00:00, 15.79it/s]


EPOCH 72: Train acc: 94.47% Train Loss: 0.2971 Test acc: 78.60% Test Loss: 1.6907


100%|██████████| 313/313 [00:57<00:00,  5.45it/s]
100%|██████████| 282/282 [00:17<00:00, 15.79it/s]


EPOCH 73: Train acc: 94.64% Train Loss: 0.2996 Test acc: 78.49% Test Loss: 1.6881


100%|██████████| 313/313 [00:57<00:00,  5.42it/s]
100%|██████████| 282/282 [00:17<00:00, 15.68it/s]


EPOCH 74: Train acc: 94.22% Train Loss: 0.3140 Test acc: 78.68% Test Loss: 1.6361


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:18<00:00, 15.60it/s]


EPOCH 75: Train acc: 94.66% Train Loss: 0.3092 Test acc: 78.37% Test Loss: 1.6639


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:18<00:00, 15.64it/s]


EPOCH 76: Train acc: 94.11% Train Loss: 0.3082 Test acc: 78.58% Test Loss: 1.7097


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:17<00:00, 15.87it/s]


EPOCH 77: Train acc: 94.66% Train Loss: 0.3043 Test acc: 79.20% Test Loss: 1.5818


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
100%|██████████| 282/282 [00:17<00:00, 15.83it/s]


EPOCH 78: Train acc: 94.53% Train Loss: 0.3027 Test acc: 78.30% Test Loss: 1.7168


100%|██████████| 313/313 [00:57<00:00,  5.43it/s]
100%|██████████| 282/282 [00:17<00:00, 15.76it/s]


EPOCH 79: Train acc: 94.04% Train Loss: 0.3226 Test acc: 78.31% Test Loss: 1.6977


100%|██████████| 313/313 [00:57<00:00,  5.45it/s]
100%|██████████| 282/282 [00:17<00:00, 15.74it/s]


EPOCH 80: Train acc: 94.33% Train Loss: 0.3110 Test acc: 78.60% Test Loss: 1.6568


100%|██████████| 313/313 [00:57<00:00,  5.44it/s]
 41%|████      | 115/282 [00:07<00:10, 16.33it/s]

In [None]:
# Restore the best checkpoint written by the training cell (useful when the
# training cell was interrupted before its own reload step ran).
model.load_state_dict(torch.load(config.SAVE_PATH+model_filename))

In [None]:
# Evaluate the restored model on the test loader; the predictions and targets
# feed the per-class analysis below.
test_preds, test_targets, test_loss = engine.eval_fn(model, test_loader, criterion, device=device)

In [None]:
# Overall test accuracy of the restored checkpoint.
accuracy_score(test_targets, test_preds)

---

### Accuracy per class and confusion matrix

In [None]:
# Per-class accuracy and raw confusion counts.
# `matrix[i, j]` counts test samples whose true class is i and predicted class is j.
# Inputs are coerced to 1-D integer arrays so the computation also works when
# eval_fn returns Python lists (a list compared with `== i` yields a single
# bool, which would silently select nothing).
targets_arr = np.asarray(test_targets).astype(int).ravel()
preds_arr = np.asarray(test_preds).astype(int).ravel()
num_classes = len(class_mapper)

matrix = np.zeros((num_classes, num_classes), dtype=int)
# Unbuffered scatter-add: repeated (target, pred) pairs are each counted.
np.add.at(matrix, (targets_arr, preds_arr), 1)

class_acc = dict()
for i in range(num_classes):
    support = int(matrix[i].sum())
    # Accuracy restricted to samples of class i = diagonal count / row total;
    # NaN marks a class that has no test samples.
    class_acc[class_mapper[i]] = float(matrix[i, i]) / support if support else float('nan')

In [None]:
# Display the accuracy obtained for each class name.
class_acc

In [None]:
print(f"Global Test accuracy {accuracy_score(test_targets, test_preds)}")

# Row-normalise the counts to percentages of each true class. Rows with no
# samples become NaN instead of triggering a divide-by-zero warning.
sum_of_rows = matrix.sum(axis=1)
row_totals = sum_of_rows[:, np.newaxis]
conf_mat = np.divide(
    100 * matrix,
    row_totals,
    out=np.full(matrix.shape, np.nan),
    where=row_totals > 0,
)

class_names = list(class_acc.keys())
df_cm = pd.DataFrame(conf_mat, index=class_names, columns=class_names)

# Rows are true classes and columns are predictions. A fixed one-decimal format
# is used because seaborn's default ('.2g') renders 100 as '1e+02'.
fig, ax = plt.subplots(figsize=(8, 5))
sns.heatmap(df_cm, annot=True, fmt=".1f", ax=ax)
ax.set_title("Confusion Matrix (Accuracy %)")
ax.set_xlabel("Predicted class")
ax.set_ylabel("True class");