In [None]:
!pip install torch torchvision scikit-learn tqdm matplotlib



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from tqdm import tqdm
import numpy as np

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cpu


In [None]:
# Root of the pre-split dataset; expects train/ val/ test/ sub-folders, each with
# one sub-folder per class (ImageFolder layout).
DATA_DIR = "dataset_split/mouth"

# Seed the RNGs used below (augmentation, DataLoader shuffling, head init) so a
# fresh Restart & Run All reproduces the same run.
# NOTE(review): bit-exact GPU runs may additionally need cuDNN determinism flags.
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
np.random.seed(SEED)

In [None]:
# Input size and ImageNet channel statistics; the backbone below is loaded with
# ImageNet weights, so inputs are normalised the same way.
IMAGE_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: light geometric/photometric augmentation, then tensor + normalise.
train_transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Validation/test pipeline: deterministic resize only, identical normalisation.
val_test_transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [None]:
BATCH_SIZE = 32

train_ds = datasets.ImageFolder(f"{DATA_DIR}/train", transform=train_transforms)
val_ds   = datasets.ImageFolder(f"{DATA_DIR}/val", transform=val_test_transforms)
test_ds  = datasets.ImageFolder(f"{DATA_DIR}/test", transform=val_test_transforms)

# Each ImageFolder derives label indices from its *own* sub-folder names, so a
# missing or misspelt class folder in one split would silently remap labels and
# corrupt every metric computed on it. Fail loudly instead.
for split_name, split_ds in (("val", val_ds), ("test", test_ds)):
    assert split_ds.class_to_idx == train_ds.class_to_idx, (
        f"{split_name} classes {split_ds.class_to_idx} != train {train_ds.class_to_idx}"
    )
# The classifier head emits a single logit (binary), so exactly two classes are expected.
assert len(train_ds.classes) == 2, f"expected 2 classes, got {train_ds.classes}"

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

print("Classes:", train_ds.classes)
print("Train:", len(train_ds), "Val:", len(val_ds), "Test:", len(test_ds))

Classes: ['no_yawn', 'yawn']
Train: 1013 Val: 216 Test: 219


In [None]:
# EfficientNet-B0 backbone with ImageNet-1k (v1) pretrained weights, named via the weights enum.
model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)

In [None]:
# Swap the final Linear layer for a single-logit head (binary yawn / no_yawn),
# keeping the original input width, then move the whole model to the device.
head_in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(head_in_features, 1)
model = model.to(device)

In [None]:
# Stage 1 (feature extraction): freeze every parameter of the convolutional
# backbone so only the classifier head learns. The trailing ';' keeps the
# notebook from echoing the returned module's repr.
model.features.requires_grad_(False);

In [None]:
# Sigmoid + binary cross-entropy fused into one op, matching the single-logit head.
criterion = nn.BCEWithLogitsLoss()

# While the backbone is frozen, only the classifier head's parameters are optimised.
optimizer = optim.Adam(params=model.classifier.parameters(), lr=1e-4)

In [None]:
def train_one_epoch(model, loader, *, optimizer=None, criterion=None, device=None):
    """Run one optimisation pass over ``loader`` and return the mean training loss.

    Args:
        model: network to train; switched to ``train()`` mode here.
        loader: iterable of ``(images, labels)`` batches. Labels are class
            indices, cast to float targets of shape ``(B, 1)``.
        optimizer, criterion, device: optional keyword overrides. When omitted,
            the notebook-level globals of the same name are looked up at call
            time, so rebinding ``optimizer`` in a later cell is still honoured.

    Returns:
        Per-sample mean loss over the epoch. Each batch is weighted by its size,
        so a smaller final batch does not skew the average.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    # Resolve lazily. Default argument values would freeze the objects that
    # existed at definition time, but the fine-tuning cell rebinds `optimizer`.
    module_globals = globals()
    optimizer = optimizer if optimizer is not None else module_globals["optimizer"]
    criterion = criterion if criterion is not None else module_globals["criterion"]
    device = device if device is not None else module_globals["device"]

    model.train()
    loss_sum = 0.0
    n_samples = 0

    for images, labels in tqdm(loader):
        images = images.to(device)
        labels = labels.float().unsqueeze(1).to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Assumes the criterion averages over the batch (BCEWithLogitsLoss default),
        # so re-weight by batch size to get a true per-sample mean.
        batch_size = images.size(0)
        loss_sum += loss.item() * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return loss_sum / n_samples

In [None]:
def evaluate(model, loader, *, threshold=0.5, device=None):
    """Score a single-logit binary classifier on ``loader`` without tracking gradients.

    Args:
        model: network emitting one logit per sample; switched to ``eval()`` mode.
        loader: iterable of ``(images, labels)`` batches.
        threshold: probability above which a sample is predicted as class 1.
        device: optional override; defaults to the notebook-level ``device``
            global, looked up at call time.

    Returns:
        ``(y_true, y_pred, y_prob)`` as 1-D NumPy arrays of equal length, so they
        compare elementwise without accidental (N, 1) vs (N,) broadcasting.
    """
    if device is None:
        device = globals()["device"]

    model.eval()
    y_true, y_pred, y_prob = [], [], []

    with torch.no_grad():
        for images, labels in loader:
            logits = model(images.to(device))

            # Flatten (B, 1) logits to (B,) so every output is one scalar per sample.
            probs = torch.sigmoid(logits).reshape(-1).cpu().numpy()
            preds = (probs > threshold).astype(int)

            y_prob.extend(probs)
            y_pred.extend(preds)
            y_true.extend(labels.reshape(-1).cpu().numpy())

    return np.array(y_true), np.array(y_pred), np.array(y_prob)

In [None]:
EPOCHS = 20

# Stage 1: train only the classifier head, reporting validation metrics each epoch.
for epoch_num in range(1, EPOCHS + 1):
    loss = train_one_epoch(model, train_loader)
    y_true, y_pred, _ = evaluate(model, val_loader)
    val_report = classification_report(y_true, y_pred, target_names=train_ds.classes)

    print(f"\nEpoch {epoch_num}/{EPOCHS}")
    print("Train Loss:", round(loss, 4))
    print(val_report)

100%|██████████| 32/32 [02:05<00:00,  3.93s/it]



Epoch 1/20
Train Loss: 0.6922
              precision    recall  f1-score   support

     no_yawn       0.59      0.47      0.53       108
        yawn       0.56      0.68      0.61       108

    accuracy                           0.57       216
   macro avg       0.58      0.57      0.57       216
weighted avg       0.58      0.57      0.57       216



100%|██████████| 32/32 [01:51<00:00,  3.49s/it]



Epoch 2/20
Train Loss: 0.6775
              precision    recall  f1-score   support

     no_yawn       0.66      0.59      0.62       108
        yawn       0.63      0.69      0.66       108

    accuracy                           0.64       216
   macro avg       0.65      0.64      0.64       216
weighted avg       0.65      0.64      0.64       216



100%|██████████| 32/32 [01:52<00:00,  3.51s/it]



Epoch 3/20
Train Loss: 0.6658
              precision    recall  f1-score   support

     no_yawn       0.73      0.55      0.62       108
        yawn       0.64      0.80      0.71       108

    accuracy                           0.67       216
   macro avg       0.68      0.67      0.67       216
weighted avg       0.68      0.67      0.67       216



100%|██████████| 32/32 [01:51<00:00,  3.49s/it]



Epoch 4/20
Train Loss: 0.6515
              precision    recall  f1-score   support

     no_yawn       0.79      0.56      0.66       108
        yawn       0.66      0.85      0.74       108

    accuracy                           0.71       216
   macro avg       0.73      0.71      0.70       216
weighted avg       0.73      0.71      0.70       216



100%|██████████| 32/32 [02:05<00:00,  3.92s/it]



Epoch 5/20
Train Loss: 0.6431
              precision    recall  f1-score   support

     no_yawn       0.75      0.63      0.68       108
        yawn       0.68      0.79      0.73       108

    accuracy                           0.71       216
   macro avg       0.71      0.71      0.71       216
weighted avg       0.71      0.71      0.71       216



100%|██████████| 32/32 [01:55<00:00,  3.61s/it]



Epoch 6/20
Train Loss: 0.6284
              precision    recall  f1-score   support

     no_yawn       0.77      0.69      0.73       108
        yawn       0.72      0.80      0.75       108

    accuracy                           0.74       216
   macro avg       0.74      0.74      0.74       216
weighted avg       0.74      0.74      0.74       216



100%|██████████| 32/32 [01:57<00:00,  3.68s/it]



Epoch 7/20
Train Loss: 0.6201
              precision    recall  f1-score   support

     no_yawn       0.79      0.69      0.73       108
        yawn       0.72      0.81      0.77       108

    accuracy                           0.75       216
   macro avg       0.75      0.75      0.75       216
weighted avg       0.75      0.75      0.75       216



100%|██████████| 32/32 [01:54<00:00,  3.57s/it]



Epoch 8/20
Train Loss: 0.6095
              precision    recall  f1-score   support

     no_yawn       0.80      0.65      0.71       108
        yawn       0.70      0.83      0.76       108

    accuracy                           0.74       216
   macro avg       0.75      0.74      0.74       216
weighted avg       0.75      0.74      0.74       216



100%|██████████| 32/32 [02:05<00:00,  3.91s/it]



Epoch 9/20
Train Loss: 0.6028
              precision    recall  f1-score   support

     no_yawn       0.78      0.77      0.78       108
        yawn       0.77      0.79      0.78       108

    accuracy                           0.78       216
   macro avg       0.78      0.78      0.78       216
weighted avg       0.78      0.78      0.78       216



100%|██████████| 32/32 [01:54<00:00,  3.57s/it]



Epoch 10/20
Train Loss: 0.5992
              precision    recall  f1-score   support

     no_yawn       0.80      0.74      0.77       108
        yawn       0.76      0.81      0.79       108

    accuracy                           0.78       216
   macro avg       0.78      0.78      0.78       216
weighted avg       0.78      0.78      0.78       216



100%|██████████| 32/32 [01:54<00:00,  3.57s/it]



Epoch 11/20
Train Loss: 0.588
              precision    recall  f1-score   support

     no_yawn       0.79      0.70      0.75       108
        yawn       0.73      0.81      0.77       108

    accuracy                           0.76       216
   macro avg       0.76      0.76      0.76       216
weighted avg       0.76      0.76      0.76       216



100%|██████████| 32/32 [01:54<00:00,  3.57s/it]



Epoch 12/20
Train Loss: 0.5795
              precision    recall  f1-score   support

     no_yawn       0.78      0.70      0.74       108
        yawn       0.73      0.81      0.77       108

    accuracy                           0.75       216
   macro avg       0.76      0.75      0.75       216
weighted avg       0.76      0.75      0.75       216



100%|██████████| 32/32 [01:54<00:00,  3.59s/it]



Epoch 13/20
Train Loss: 0.5771
              precision    recall  f1-score   support

     no_yawn       0.81      0.73      0.77       108
        yawn       0.75      0.82      0.79       108

    accuracy                           0.78       216
   macro avg       0.78      0.78      0.78       216
weighted avg       0.78      0.78      0.78       216



100%|██████████| 32/32 [02:06<00:00,  3.94s/it]



Epoch 14/20
Train Loss: 0.5612
              precision    recall  f1-score   support

     no_yawn       0.80      0.81      0.80       108
        yawn       0.80      0.80      0.80       108

    accuracy                           0.80       216
   macro avg       0.80      0.80      0.80       216
weighted avg       0.80      0.80      0.80       216



100%|██████████| 32/32 [01:54<00:00,  3.58s/it]



Epoch 15/20
Train Loss: 0.5715
              precision    recall  f1-score   support

     no_yawn       0.83      0.72      0.77       108
        yawn       0.75      0.85      0.80       108

    accuracy                           0.79       216
   macro avg       0.79      0.79      0.79       216
weighted avg       0.79      0.79      0.79       216



100%|██████████| 32/32 [01:53<00:00,  3.56s/it]



Epoch 16/20
Train Loss: 0.5594
              precision    recall  f1-score   support

     no_yawn       0.81      0.81      0.81       108
        yawn       0.81      0.81      0.81       108

    accuracy                           0.81       216
   macro avg       0.81      0.81      0.81       216
weighted avg       0.81      0.81      0.81       216



100%|██████████| 32/32 [01:54<00:00,  3.58s/it]



Epoch 17/20
Train Loss: 0.5473
              precision    recall  f1-score   support

     no_yawn       0.83      0.73      0.78       108
        yawn       0.76      0.85      0.80       108

    accuracy                           0.79       216
   macro avg       0.80      0.79      0.79       216
weighted avg       0.80      0.79      0.79       216



100%|██████████| 32/32 [02:04<00:00,  3.90s/it]



Epoch 18/20
Train Loss: 0.5523
              precision    recall  f1-score   support

     no_yawn       0.82      0.82      0.82       108
        yawn       0.82      0.82      0.82       108

    accuracy                           0.82       216
   macro avg       0.82      0.82      0.82       216
weighted avg       0.82      0.82      0.82       216



100%|██████████| 32/32 [01:53<00:00,  3.55s/it]



Epoch 19/20
Train Loss: 0.5402
              precision    recall  f1-score   support

     no_yawn       0.83      0.80      0.82       108
        yawn       0.81      0.84      0.82       108

    accuracy                           0.82       216
   macro avg       0.82      0.82      0.82       216
weighted avg       0.82      0.82      0.82       216



100%|██████████| 32/32 [01:54<00:00,  3.58s/it]



Epoch 20/20
Train Loss: 0.5408
              precision    recall  f1-score   support

     no_yawn       0.83      0.81      0.82       108
        yawn       0.82      0.83      0.83       108

    accuracy                           0.82       216
   macro avg       0.82      0.82      0.82       216
weighted avg       0.82      0.82      0.82       216



In [None]:
# Stage 2 (fine-tuning): unfreeze the last feature blocks and train them together
# with the head at a much smaller learning rate.
N_UNFROZEN_BLOCKS = 2
FINE_TUNE_LR = 1e-5

for param in model.features[-N_UNFROZEN_BLOCKS:].parameters():
    param.requires_grad = True

# Hand Adam only the parameters that will actually receive gradients. The
# still-frozen backbone would otherwise sit in its param groups, never updated.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=FINE_TUNE_LR)


In [None]:
FINE_TUNE_EPOCHS = 8

# Stage 2 training loop: same per-epoch validation report as stage 1.
for epoch in range(FINE_TUNE_EPOCHS):
    loss = train_one_epoch(model, train_loader)
    y_true, y_pred, _ = evaluate(model, val_loader)

    print(f"\nFine-Tune Epoch {epoch+1}/{FINE_TUNE_EPOCHS}")
    print("Loss:", round(loss, 4))
    # Name the classes, as the stage-1 loop does, instead of printing raw 0/1 labels.
    print(classification_report(y_true, y_pred, target_names=train_ds.classes))

100%|██████████| 32/32 [01:58<00:00,  3.72s/it]



Fine-Tune Epoch 1
Loss: 0.5258
              precision    recall  f1-score   support

           0       0.83      0.83      0.83       108
           1       0.83      0.82      0.83       108

    accuracy                           0.83       216
   macro avg       0.83      0.83      0.83       216
weighted avg       0.83      0.83      0.83       216



100%|██████████| 32/32 [01:58<00:00,  3.71s/it]



Fine-Tune Epoch 2
Loss: 0.5098
              precision    recall  f1-score   support

           0       0.83      0.84      0.83       108
           1       0.84      0.82      0.83       108

    accuracy                           0.83       216
   macro avg       0.83      0.83      0.83       216
weighted avg       0.83      0.83      0.83       216



100%|██████████| 32/32 [02:08<00:00,  4.03s/it]



Fine-Tune Epoch 3
Loss: 0.5217
              precision    recall  f1-score   support

           0       0.83      0.83      0.83       108
           1       0.83      0.83      0.83       108

    accuracy                           0.83       216
   macro avg       0.83      0.83      0.83       216
weighted avg       0.83      0.83      0.83       216



100%|██████████| 32/32 [01:58<00:00,  3.69s/it]



Fine-Tune Epoch 4
Loss: 0.4929
              precision    recall  f1-score   support

           0       0.85      0.83      0.84       108
           1       0.84      0.85      0.84       108

    accuracy                           0.84       216
   macro avg       0.84      0.84      0.84       216
weighted avg       0.84      0.84      0.84       216



100%|██████████| 32/32 [01:58<00:00,  3.71s/it]



Fine-Tune Epoch 5
Loss: 0.499
              precision    recall  f1-score   support

           0       0.85      0.81      0.83       108
           1       0.82      0.86      0.84       108

    accuracy                           0.83       216
   macro avg       0.83      0.83      0.83       216
weighted avg       0.83      0.83      0.83       216



100%|██████████| 32/32 [01:59<00:00,  3.72s/it]



Fine-Tune Epoch 6
Loss: 0.4775
              precision    recall  f1-score   support

           0       0.85      0.86      0.85       108
           1       0.86      0.84      0.85       108

    accuracy                           0.85       216
   macro avg       0.85      0.85      0.85       216
weighted avg       0.85      0.85      0.85       216



100%|██████████| 32/32 [02:09<00:00,  4.03s/it]



Fine-Tune Epoch 7
Loss: 0.4799
              precision    recall  f1-score   support

           0       0.85      0.87      0.86       108
           1       0.87      0.85      0.86       108

    accuracy                           0.86       216
   macro avg       0.86      0.86      0.86       216
weighted avg       0.86      0.86      0.86       216



100%|██████████| 32/32 [01:58<00:00,  3.69s/it]



Fine-Tune Epoch 8
Loss: 0.4581
              precision    recall  f1-score   support

           0       0.85      0.87      0.86       108
           1       0.87      0.85      0.86       108

    accuracy                           0.86       216
   macro avg       0.86      0.86      0.86       216
weighted avg       0.86      0.86      0.86       216



In [None]:
# Final held-out evaluation: confusion matrix, per-class report and ROC-AUC
# (computed from the sigmoid probabilities, not the thresholded predictions).
y_true, y_pred, y_prob = evaluate(model, test_loader)

test_cm = confusion_matrix(y_true, y_pred)
test_report = classification_report(y_true, y_pred, target_names=train_ds.classes)
test_auc = roc_auc_score(y_true, y_prob)

print("Confusion Matrix:")
print(test_cm)
print("\nClassification Report:")
print(test_report)
print("ROC-AUC:", test_auc)


In [None]:
# Persist only the weights (state_dict). Reloading requires rebuilding the same
# efficientnet_b0 architecture with the 1-output classifier[1] head first.
MODEL_PATH = "mouth_state_efficientnet_b0.pth"
mouth_state_dict = model.state_dict()
torch.save(mouth_state_dict, MODEL_PATH)

print("Mouth-state model saved!")
