In [1]:
import os
import shutil

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

from emotions_utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Project root relative to this notebook; every data/model/log location hangs off it.
CORE_PATH = "../"
# os.path.join avoids the doubled separator ("..//emotions") that f"{CORE_PATH}/..." produced.
ORIG_PATH = os.path.join(CORE_PATH, "emotions")             # ImageFolder root for the dataset
SAVE_LOGS_PATH = os.path.join(CORE_PATH, "missclassified")  # mistaken-image reports (dir name kept as-is)
SAVE_MODELS_PATH = os.path.join(CORE_PATH, "models")        # saved weights / checkpoints

In [3]:
# Index the image tree under ORIG_PATH; ImageFolder derives the labels from sub-folder names.
data = ImageFolder(root=ORIG_PATH)
data

Dataset ImageFolder
    Number of datapoints: 496
    Root location: ..//emotions

In [4]:
# Emotion name -> integer label (as assigned by ImageFolder), plus the inverse lookup
# used later to turn predicted ids back into readable names.
class_to_idx = data.class_to_idx
idx_to_class = dict(zip(class_to_idx.values(), class_to_idx.keys()))

class_to_idx

{'anger': 0,
 'contempt': 1,
 'disgust': 2,
 'fear': 3,
 'joy': 4,
 'sadness': 5,
 'wonder': 6}

In [5]:
# Seed the global RNG right before splitting so the subsets are reproducible.
torch.manual_seed(0)

# Train / validation fractions; the test share is whatever remains.
train_size, val_size = 0.75, 0.1
test_size = 1 - train_size - val_size
split_fractions = [train_size, val_size, test_size]

train_data, val_data, test_data = torch.utils.data.random_split(data, split_fractions)

# Resulting subset sizes.
tuple(map(len, (train_data, val_data, test_data)))

(373, 49, 74)

In [6]:
BATCH_SIZE = 16

# Wrap each split with its transform pipeline: `train_transforms` for training,
# `inf_transforms` for validation/test (both provided by emotions_utils).
train_dataset = EmotionsDataset(train_data, train_transforms)
val_dataset, test_dataset = (
    EmotionsDataset(subset, inf_transforms) for subset in (val_data, test_data)
)

# Training uses the project helper; evaluation uses plain DataLoaders (shuffle defaults to False).
train_loader = get_loader(train_dataset, BATCH_SIZE)
val_loader, test_loader = (
    DataLoader(ds, BATCH_SIZE) for ds in (val_dataset, test_dataset)
)

In [7]:
# Optimisation hyper-parameters.
LR = 1e-3
EPOCHS = 100

# Prefer the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Multi-class loss over the emotion logits.
criterion = nn.CrossEntropyLoss()

In [8]:
from torchvision.models import vit_b_16, ViT_B_16_Weights

In [10]:
# Pretrained ViT-B/16 backbone; only a new classification head will be trained.
model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)

# Freeze every backbone parameter (the loop variable `name` was unused, so iterate parameters directly).
for param in model.parameters():
    param.requires_grad = False

HEAD_HIDDEN_DIM = 512
HEAD_DROPOUT = 0.5  # same as nn.Dropout()'s default
# The head is created after the freeze loop, so its parameters keep requires_grad=True.
classifier = nn.Sequential(
    nn.Linear(model.heads[0].in_features, HEAD_HIDDEN_DIM),
    nn.ReLU(),
    nn.Dropout(p=HEAD_DROPOUT),
    nn.Linear(HEAD_HIDDEN_DIM, len(class_to_idx)),
)
model.heads = classifier

# Hand only trainable parameters to the optimizer: frozen ones never receive gradients,
# so passing them was noise (Adam skips them either way, so updates are unchanged).
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=LR
)

In [11]:
# Fine-tune the head. train_model writes `best_ckpt_name` to the working directory
# (presumably the best-validation epoch -- it is what the later cells reload as `best`).
best_ckpt_name = "best_vit16_100.pt"
stopper = EarlyStopping(3)  # 3 -- presumably patience in epochs; see emotions_utils
loss_train, loss_val = train_model(
    model, train_loader, val_loader, EPOCHS, criterion, optimizer, device,
    best_ckpt_name,
    early_stopping=stopper,
)

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
100%|██████████| 24/24 [00:11<00:00,  2.08it/s]


EPOCH: 1, train_loss: 1.851833812033523, val_loss: 1.7545764665214383
train_f1: 0.24010416666666667, val_f1: 0.21875



100%|██████████| 24/24 [00:11<00:00,  2.18it/s]


EPOCH: 2, train_loss: 1.6280841210572394, val_loss: 1.650626785901128
train_f1: 0.36354166666666665, val_f1: 0.25



100%|██████████| 24/24 [00:10<00:00,  2.24it/s]


EPOCH: 3, train_loss: 1.5134193500948336, val_loss: 1.6558296096568206
train_f1: 0.4145833333333333, val_f1: 0.21875



100%|██████████| 24/24 [00:10<00:00,  2.24it/s]


EPOCH: 4, train_loss: 1.473401653223319, val_loss: 1.5839101149111379
train_f1: 0.46458333333333335, val_f1: 0.234375



100%|██████████| 24/24 [00:10<00:00,  2.24it/s]


EPOCH: 5, train_loss: 1.3614791197048115, val_loss: 1.6293716819918886
train_f1: 0.47031249999999997, val_f1: 0.234375



100%|██████████| 24/24 [00:09<00:00,  2.44it/s]


EPOCH: 6, train_loss: 1.2959421479989632, val_loss: 1.6697646160514987
train_f1: 0.5140625, val_f1: 0.25



100%|██████████| 24/24 [00:10<00:00,  2.29it/s]


EPOCH: 7, train_loss: 1.2216868931103009, val_loss: 1.506736658057388
train_f1: 0.5234375, val_f1: 0.3125



100%|██████████| 24/24 [00:10<00:00,  2.36it/s]


EPOCH: 8, train_loss: 1.3069061932550998, val_loss: 1.6221353083240742
train_f1: 0.525, val_f1: 0.21875



100%|██████████| 24/24 [00:10<00:00,  2.23it/s]


EPOCH: 9, train_loss: 1.154951261770949, val_loss: 1.6984441523649254
train_f1: 0.5854166666666667, val_f1: 0.25



100%|██████████| 24/24 [00:10<00:00,  2.19it/s]


EPOCH: 10, train_loss: 1.0316877353926446, val_loss: 1.7239181265539052
train_f1: 0.6401041666666667, val_f1: 0.296875



100%|██████████| 24/24 [00:10<00:00,  2.24it/s]


EPOCH: 11, train_loss: 1.1250755227602838, val_loss: 1.7271901734021244
train_f1: 0.5984375000000001, val_f1: 0.28125

Early Stopping!


In [17]:
# Persist the weights `model` holds after train_model returned, and keep a copy of the
# checkpoint train_model wrote to the working directory. Create the target dir first so
# this cell also works on a fresh checkout.
os.makedirs(SAVE_MODELS_PATH, exist_ok=True)
torch.save(model.state_dict(), f"{SAVE_MODELS_PATH}/vit16_100_train.pt")
shutil.copyfile("best_vit16_100.pt", f"{SAVE_MODELS_PATH}/best_vit16_100.pt")

'..//models/best_vit16_100.pt'

In [45]:
# Rebuild the fine-tuned ViT-B/16 and load the best checkpoint copied to SAVE_MODELS_PATH.
# weights=None: the strict load_state_dict below overwrites every parameter anyway, so
# downloading/initialising the ImageNet weights first is wasted work.
best = vit_b_16(weights=None)
# Head must match the one used in training (hidden 512, dropout 0.5) or the state dict won't load.
classifier = nn.Sequential(nn.Linear(best.heads[0].in_features, 512),
                           nn.ReLU(),
                           nn.Dropout(),
                           nn.Linear(512, len(class_to_idx)))
best.heads = classifier
# map_location lets a checkpoint saved on GPU load on a CPU-only machine;
# weights_only avoids unpickling arbitrary objects from the file.
best_state = torch.load(f"{SAVE_MODELS_PATH}/best_vit16_100.pt",
                        map_location=device, weights_only=True)
best.load_state_dict(best_state)
best.eval()
best = best.to(device)

In [12]:
# Score `model` (the weights it holds after train_model returned) on the held-out test loader.
# test_model returns a single number (0.31 in the saved output); the exact metric is
# defined in emotions_utils -- TODO confirm whether it is accuracy or F1.
test_model(model, test_loader, device)

0.31


In [10]:
# Same test-set score for `best`, the checkpoint reloaded from SAVE_MODELS_PATH.
# In the saved run it scores below the final `model` (0.2775 vs 0.31).
test_model(best, test_loader, device)

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


0.27749999999999997


In [11]:
# Evaluate `best` on every image that was not in the training split (validation + test).
# NOTE(review): the validation subset also drove early stopping (and presumably checkpoint
# selection), so these metrics are optimistic -- consider test_data.indices alone.
not_train_indices = [*val_data.indices, *test_data.indices]
eval_samples = [data.imgs[i] for i in not_train_indices]  # (image path, class id) pairs
inf_images = [path for path, _ in eval_samples]
inf_labels = [label for _, label in eval_samples]

preds = inference_model(best, inf_images, device)
get_metrics_report(inf_labels, preds)

{'Accuracy': 0.3333333333333333,
 'Precision_macro': 0.29018174895236365,
 'Precision_micro': 0.3333333333333333,
 'Recall_macro': 0.2956856677157429,
 'Recall_micro': 0.3333333333333333,
 'ROC_AUC': {0: 0.6199213630406291,
  1: 0.6072874493927126,
  2: 0.49557522123893805,
  3: 0.486013986013986,
  4: 0.7271844660194176,
  5: 0.6532653061224489,
  6: 0.5513051305130514}}

In [13]:
# Same evaluation for `model` (weights after train_model returned), reusing the
# inf_images / inf_labels built for the `best` evaluation above instead of rebuilding them.
# Predictions are stored under their own name: overwriting `preds` made the report cells
# below silently switch from `best`'s predictions to `model`'s depending on execution order.
model_preds = inference_model(model, inf_images, device)
get_metrics_report(inf_labels, model_preds)

{'Accuracy': 0.3333333333333333,
 'Precision_macro': 0.3358543096872616,
 'Precision_micro': 0.3333333333333333,
 'Recall_macro': 0.3214825024599461,
 'Recall_micro': 0.3333333333333333,
 'ROC_AUC': {0: 0.6245085190039318,
  1: 0.659919028340081,
  2: 0.5190265486725665,
  3: 0.5786713286713286,
  4: 0.716747572815534,
  5: 0.6038775510204082,
  6: 0.5335283528352834}}

In [12]:
# Per-class "correct / total" counts for `preds`, with class ids mapped back to emotion names.
# NOTE(review): `preds` is whatever inference cell assigned it most recently. The saved output
# below matches `best`'s predictions (its per-class recalls average to best's
# Recall_macro ~0.2957), so a top-to-bottom re-run must reproduce that -- confirm.
get_classification_report(inf_labels, preds, idx_to_class)

anger emotion
Overall images: 14
Correctly predicted 4/14

contempt emotion
Overall images: 19
Correctly predicted 7/19

disgust emotion
Overall images: 10
Correctly predicted 0/10

fear emotion
Overall images: 13
Correctly predicted 2/13

joy emotion
Overall images: 20
Correctly predicted 12/20

sadness emotion
Overall images: 25
Correctly predicted 12/25

wonder emotion
Overall images: 22
Correctly predicted 4/22



In [None]:
# Report the evaluation images whose prediction differs from the label, tagged "vit16_100"
# (presumably written under SAVE_LOGS_PATH -- see emotions_utils). This cell has not been run.
# NOTE(review): uses whichever `preds` was assigned last; confirm it is the intended model.
get_mistaken_images_report(inf_images, inf_labels, preds, "vit16_100", idx_to_class, SAVE_LOGS_PATH)