In [1]:
import os
# Display the kernel's current working directory (IPython line magic) so the
# starting location is visible before the next cell changes it.
%pwd

'/mnt/cb03386d-9344-47b1-82f9-868fbb64b4ae/python_projects/facial_expression_detection/research'

In [2]:
os.chdir('../')
%pwd

'/mnt/cb03386d-9344-47b1-82f9-868fbb64b4ae/python_projects/facial_expression_detection'

In [3]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class ModelTrainerConfig:
    """Immutable bundle of settings for the model-training stage.

    Built by ``ConfigurationManager.get_model_trainer_config`` and consumed by
    ``ModelTrainer``. ``frozen=True`` makes attribute assignment on an instance
    raise ``dataclasses.FrozenInstanceError``.
    """
    # Root directory of the model-trainer stage (copied from the config file).
    root_dir: Path
    # Directory where the per-fold ``efficientnet_fold_<k>.pth`` state dicts are written.
    models: Path
    # Directory where the per-fold PNG figures (loss, accuracy, confusion matrix) are saved.
    figures: Path
    # Second positional argument given to ``CustomImageDataset``
    # (presumably the image directory - TODO confirm against the dataset class).
    dataset_folder: Path
    # First positional argument given to ``CustomImageDataset``
    # (presumably the labels file - TODO confirm against the dataset class).
    dataset_labels: Path
    # Hyper-parameters; ``ModelTrainer.train`` reads ``img_in_size``, ``num_folds``,
    # ``batch_size``, ``num_classes``, ``lr`` and ``num_epochs`` from it.
    # NOTE(review): annotated as ``dict`` but accessed by attribute
    # (e.g. ``model_params.img_in_size``), so the runtime object is presumably an
    # attribute-access mapping returned by ``read_yaml`` - confirm.
    model_params: dict

In [4]:
from src.detmood.constant import *
from src.detmood.utils.main_utils import create_directories, read_yaml

class ConfigurationManager:
    """Reads the project's YAML files and produces per-stage config objects.

    The config, params and schema documents are loaded once at construction
    time with ``read_yaml`` and kept on ``self.config``, ``self.params`` and
    ``self.schema``. Their contents are accessed by attribute below.
    """

    def __init__(
        self,
        config_file_path = CONFIG_FILE_PATH,
        params_file_path = PARAMS_FILE_PATH,
        schema_file_path = SCHEMA_FILE_PATH
    ):
        """Load the three YAML documents and create the artifacts root directory.

        Args:
            config_file_path: Path of the main configuration YAML.
            params_file_path: Path of the hyper-parameter YAML.
            schema_file_path: Path of the schema YAML.
        """
        # Read in the same order as the arguments: config, params, schema.
        yaml_paths = (config_file_path, params_file_path, schema_file_path)
        self.config, self.params, self.schema = (
            read_yaml(yaml_path) for yaml_path in yaml_paths
        )

        artifacts_root = self.config.artifacts_root
        create_directories([artifacts_root])

    def get_model_trainer_config(self) -> ModelTrainerConfig:
        """Assemble the ``ModelTrainerConfig`` for the training stage.

        Creates the ``models`` and ``figures`` output directories named in the
        ``model_trainer`` section of the config before returning.

        Returns:
            ModelTrainerConfig: Paths from ``config.model_trainer`` together with
            the ``model`` section of the params file as ``model_params``.
        """
        trainer_section = self.config.model_trainer
        hyper_params = self.params.model

        output_dirs = [trainer_section.models, trainer_section.figures]
        create_directories(output_dirs)

        return ModelTrainerConfig(
            root_dir=trainer_section.root_dir,
            models=trainer_section.models,
            figures=trainer_section.figures,
            dataset_folder=trainer_section.dataset_folder,
            dataset_labels=trainer_section.dataset_labels,
            model_params=hyper_params,
        )

In [11]:
from src.detmood.constant.dataset_preparation import CustomImageDataset
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision import transforms, models
from torchvision.models import EfficientNet_B0_Weights
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix
import numpy as np
import pandas as pd
import os
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

class ModelTrainer:
    """Stratified k-fold fine-tuning of a torchvision EfficientNet-B0 classifier.

    For every fold a fresh pretrained model is trained for
    ``model_params.num_epochs`` epochs. The weights of the epoch with the lowest
    validation loss are written to ``config.models``, and three figures (loss
    curves, accuracy curves, confusion matrix) are written to ``config.figures``.
    """

    def __init__(self, config: ModelTrainerConfig):
        """Store the training configuration.

        Args:
            config: Paths and hyper-parameters for the training stage.
        """
        self.config = config

    def _build_model(self, device):
        """Create a pretrained EfficientNet-B0 with a new two-layer head on `device`."""
        model = models.efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
        # Replace the final linear layer (1280 input features) with a small MLP
        # that outputs `num_classes` logits.
        model.classifier[1] = nn.Sequential(
            nn.Linear(in_features=1280, out_features=512),
            nn.ReLU(),
            nn.Linear(
                in_features=512,
                out_features=self.config.model_params.num_classes
            )
        )
        return model.to(device)

    @staticmethod
    def _run_epoch(model, loader, criterion, device, optimizer=None, desc=None):
        """Run one pass over `loader`; trains if `optimizer` is given, else evaluates.

        Returns:
            tuple: ``(mean per-batch loss, accuracy in percent, predictions, labels)``.
            The two lists are only filled in evaluation mode.
        """
        training = optimizer is not None
        model.train(training)  # train(False) is equivalent to eval()

        loss_sum = 0.0
        num_correct = 0
        num_seen = 0
        epoch_preds = []
        epoch_labels = []

        # Gradients are only needed while training.
        with torch.set_grad_enabled(training):
            # leave=False removes finished inner bars instead of stacking them.
            for images, labels in tqdm(loader, desc=desc, leave=False):
                images, labels = images.to(device), labels.to(device)

                if training:
                    optimizer.zero_grad()

                outputs = model(images)
                loss = criterion(outputs, labels)

                if training:
                    loss.backward()
                    optimizer.step()

                loss_sum += loss.item()

                predicted = outputs.argmax(dim=1)
                num_seen += labels.size(0)
                num_correct += (predicted == labels).sum().item()

                if not training:
                    epoch_preds.extend(predicted.cpu().numpy())
                    epoch_labels.extend(labels.cpu().numpy())

        return loss_sum / len(loader), 100 * num_correct / num_seen, epoch_preds, epoch_labels

    def _plot_curves(self, train_values, val_values, metric, title, filename):
        """Save a train-vs-validation line plot of `metric` per epoch to the figures dir."""
        # Derive the x axis from the recorded history so it always matches its length.
        epochs_range = range(1, len(train_values) + 1)

        fig, ax = plt.subplots(figsize=(12, 6))
        ax.plot(epochs_range, train_values, label=f'Train {metric}')
        ax.plot(epochs_range, val_values, label=f'Validation {metric}')
        ax.set_xlabel('Epochs')
        ax.set_ylabel(metric)
        ax.set_title(title)
        ax.legend()
        fig.savefig(os.path.join(self.config.figures, filename))
        # Release the figure; otherwise three figures per fold stay open in memory.
        plt.close(fig)

    def _plot_confusion_matrix(self, true_labels, pred_labels, fold_no):
        """Save the confusion-matrix heatmap for one fold to the figures dir."""
        if len(true_labels) == 0:
            # Nothing was validated (e.g. num_epochs == 0): no matrix to draw.
            return

        class_ids = list(range(self.config.model_params.num_classes))
        # Pin the label set so the matrix is always num_classes x num_classes and
        # lines up with the tick labels even if a class is absent from this fold.
        # Class ids 0..num_classes-1 are what CrossEntropyLoss expects as targets.
        cm = confusion_matrix(true_labels, pred_labels, labels=class_ids)

        fig, ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=class_ids,
            yticklabels=class_ids,
            ax=ax
        )
        ax.set_xlabel('Predicted Labels')
        ax.set_ylabel('True Labels')
        ax.set_title(f'Confusion Matrix for Fold {fold_no}')
        fig.savefig(os.path.join(self.config.figures, f'cm_fold_{fold_no}.png'))
        plt.close(fig)

    def train(self):
        """Train one model per stratified fold and persist checkpoints and figures.

        Side effects per fold ``k`` (1-based):
            * ``<models>/efficientnet_fold_<k>.pth`` - state dict of the epoch with
              the lowest validation loss. It is rewritten every time the
              validation loss improves, so an interrupted run keeps its best
              checkpoint so far.
            * ``<figures>/train_val_lossfold_<k>.png``,
              ``<figures>/train_val_accuracy_fold_<k>.png`` and
              ``<figures>/cm_fold_<k>.png``. The confusion matrix uses the
              validation predictions of the saved (best) epoch.
        """
        params = self.config.model_params
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print('Device: ', device)

        # NOTE(review): this transform (including the random horizontal flip) is
        # attached to the single dataset shared by the train and validation
        # subsets, so validation images are randomly flipped too. Consider a
        # second dataset instance with a flip-free transform for validation.
        transform = transforms.Compose([
            transforms.Resize((params.img_in_size, params.img_in_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        dataset = CustomImageDataset(
            self.config.dataset_labels,
            self.config.dataset_folder,
            transform=transform
        )

        skf = StratifiedKFold(
            n_splits=params.num_folds,
            shuffle=True,
            random_state=42
        )
        splits = skf.split(dataset.data_frame, dataset.data_frame['label'])

        # `total` lets the outer bar show real progress instead of "0it".
        for fold, (train_idx, val_idx) in enumerate(tqdm(splits, total=params.num_folds, desc='Folds')):
            fold_no = fold + 1
            print(f'Fold {fold_no}/{params.num_folds}')

            train_loader = DataLoader(
                Subset(dataset, train_idx),
                batch_size=params.batch_size,
                shuffle=True
            )
            val_loader = DataLoader(
                Subset(dataset, val_idx),
                batch_size=params.batch_size,
                shuffle=False
            )

            model = self._build_model(device)
            criterion = torch.nn.CrossEntropyLoss()
            optimizer = torch.optim.Adam(model.parameters(), lr=params.lr)

            model_path = os.path.join(self.config.models, f'efficientnet_fold_{fold_no}.pth')

            train_losses = []
            val_losses = []
            train_accuracies = []
            val_accuracies = []

            best_val_loss = float('inf')
            best_epoch = None
            best_preds, best_labels = [], []
            last_preds, last_labels = [], []

            for epoch in tqdm(range(params.num_epochs), desc=f'Fold {fold_no} epochs'):
                avg_train_loss, train_accuracy, _, _ = self._run_epoch(
                    model, train_loader, criterion, device,
                    optimizer=optimizer, desc='Train'
                )

                print('Validation process...')
                avg_val_loss, val_accuracy, last_preds, last_labels = self._run_epoch(
                    model, val_loader, criterion, device, desc='Validation'
                )

                train_losses.append(avg_train_loss)
                train_accuracies.append(train_accuracy)
                val_losses.append(avg_val_loss)
                val_accuracies.append(val_accuracy)

                print(f'Epoch [{epoch+1}/{params.num_epochs}], '
                        f'Loss: {avg_train_loss:.4f}, '
                        f'Validation Loss: {avg_val_loss:.4f}, '
                        f'Train Accuracy: {train_accuracy:.2f}%, '
                        f'Validation Accuracy: {val_accuracy:.2f}%')

                if avg_val_loss < best_val_loss:
                    # Checkpoint now, while the model still holds the best weights.
                    best_val_loss = avg_val_loss
                    best_epoch = epoch + 1
                    best_preds, best_labels = last_preds, last_labels
                    torch.save(model.state_dict(), model_path)

            if best_epoch is None:
                # No epoch beat +inf (zero epochs, or every validation loss was
                # NaN): fall back to the final weights so the fold still yields
                # a checkpoint file.
                torch.save(model.state_dict(), model_path)
                best_preds, best_labels = last_preds, last_labels
                print(f'No improving epoch for Fold {fold_no}; saved final weights')
            else:
                print(f'Saved Best Model for Fold {fold_no} at Epoch {best_epoch}')

            # File names are kept exactly as before (including the missing
            # underscore in "lossfold") so existing consumers keep working.
            self._plot_curves(
                train_losses, val_losses, 'Loss',
                f'Train/validation Loss for Fold {fold_no}',
                f'train_val_lossfold_{fold_no}.png'
            )
            self._plot_curves(
                train_accuracies, val_accuracies, 'Accuracy',
                f'Train/Validation Accuracy for Fold {fold_no}',
                f'train_val_accuracy_fold_{fold_no}.png'
            )
            self._plot_confusion_matrix(best_labels, best_preds, fold_no)

            print(f'Finished fold {fold_no}/{params.num_folds}\n')

        print('Training completed.')

In [12]:
# Build the trainer from the YAML configuration and run cross-validated training.
# The former `try: ... except Exception as e: raise e` wrapper is removed: it
# re-raised every exception unchanged and only added a frame to the traceback.
config = ConfigurationManager()
model_trainer_config = config.get_model_trainer_config()
model_trainer = ModelTrainer(config=model_trainer_config)
model_trainer.train()

[2024-10-27 12:38:06,301: INFO: main_utils: created directory at: artifacts]
[2024-10-27 12:38:06,302: INFO: main_utils: created directory at: artifacts/model_trainer/models]
[2024-10-27 12:38:06,304: INFO: main_utils: created directory at: artifacts/model_trainer/figures]
Device:  cuda


0it [00:00, ?it/s]

Fold 1/5



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Validation process...



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 68/68 [00:09<00:00,  7.19it/s]


Epoch [1/10], Loss: 1.1391, Validation Loss: 0.9929, Train Accuracy: 59.53%, Validation Accuracy: 64.73%



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Validation process...



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 68/68 [00:09<00:00,  7.38it/s]


Epoch [2/10], Loss: 0.7786, Validation Loss: 0.8889, Train Accuracy: 72.50%, Validation Accuracy: 70.03%



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
 30%|███       | 82/269 [00:28<01:05,  2.84it/s]
 20%|██        | 2/10 [03:57<15:49, 118.69s/it]
0it [03:57, ?it/s]


KeyboardInterrupt: 