Compute dataset statistics: the mean and standard deviation of each RGB channel.
Save the results to artifacts/stats.json for later use in normalization.

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from pathlib import Path
import json

def main():
    """Compute per-channel mean and std of the images in ``data/raw``.

    Every image is resized to 224x224 and converted to a tensor, then the
    statistics are accumulated over *all pixels of all images* (per channel)
    and written to ``artifacts/stats.json`` as ``{"mean": [...], "std": [...]}``.

    The standard deviation is derived from the pooled sum and sum of squares.
    Averaging the per-image standard deviations (the previous approach)
    ignores the variation *between* images and therefore under-estimates the
    spread of the dataset.

    Raises:
        RuntimeError: if the loader yields no images.
    """
    data_dir = Path("data/raw")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # match the model input size
        transforms.ToTensor()
    ])

    dataset = datasets.ImageFolder(data_dir, transform=transform)
    # num_workers=0 is important on Windows
    loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=0)

    channel_sum = None
    channel_sq_sum = None
    pixels_per_channel = 0

    for images, _ in loader:
        batch_samples = images.size(0)  # number of images in this batch
        # B x C x (H*W); accumulate in float64 to limit rounding error
        pixels = images.view(batch_samples, images.size(1), -1).double()
        batch_sum = pixels.sum(dim=(0, 2))
        batch_sq_sum = pixels.pow(2).sum(dim=(0, 2))
        if channel_sum is None:
            channel_sum, channel_sq_sum = batch_sum, batch_sq_sum
        else:
            channel_sum += batch_sum
            channel_sq_sum += batch_sq_sum
        pixels_per_channel += batch_samples * pixels.size(2)

    if pixels_per_channel == 0:
        raise RuntimeError(f"No images found in {data_dir}")

    mean_t = channel_sum / pixels_per_channel
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative rounding noise
    var_t = (channel_sq_sum / pixels_per_channel - mean_t.pow(2)).clamp(min=0.0)

    mean = mean_t.tolist()
    std = var_t.sqrt().tolist()

    print(f"\nDataset statistics:")
    print(f"Mean: {mean}")
    print(f"Std:  {std}")

    Path("artifacts").mkdir(parents=True, exist_ok=True)

    stats_path = Path("artifacts/stats.json")
    with stats_path.open("w") as f:
        json.dump({"mean": mean, "std": std}, f)

    print(f"\nStatistics saved to {stats_path}")

if __name__ == "__main__":
    main()


Config for train.py (hyperparameters, out_dir, seed)

In [None]:
#!/usr/bin/env python3

from dataclasses import dataclass, field
from typing import List, Dict, Any
import torch

@dataclass
class DataConfig:
    """Settings for locating, splitting, batching and augmenting the image dataset."""

    data_root: str = "data/raw"  # root folder of the image dataset
    img_size: int = 224  # side length (pixels) images are resized/cropped to
    batch_size: int = 16
    num_workers: int = 2  # number of DataLoader workers
    # Split fractions; presumably expected to sum to 1.0 — nothing here enforces it.
    train_ratio: float = 0.7
    val_ratio: float = 0.15
    test_ratio: float = 0.15

    augmentation: bool = True  # enables training-time augmentation
    color_jitter: float = 0.15  # presumably the ColorJitter strength — TODO confirm against the transform pipeline
    random_rotate: int = 10  # presumably a maximum rotation in degrees — TODO confirm

@dataclass
class ModelConfig:
    """Architecture choice, fine-tuning schedule and optimizer defaults.

    The ``get_*_space`` class methods expose the candidate values offered for
    hyperparameter search; each call returns a fresh list.
    """

    model_name: str = "resnet50"
    num_classes: int = 3
    pretrained: bool = True
    freeze_backbone: bool = True
    unfreeze_epoch: int = 5

    # Optimizer hyperparameters
    learning_rate: float = 1e-3
    weight_decay: float = 1e-2

    @classmethod
    def get_learning_rate_space(cls):
        """Return the candidate learning rates, largest first."""
        candidates = (1e-1, 1e-2, 1e-3, 1e-4, 1e-5)
        return list(candidates)

    @classmethod
    def get_weight_decay_space(cls):
        """Return the candidate weight-decay values, ending with no decay."""
        candidates = (1e-1, 1e-2, 1e-3, 1e-4, 0.0)
        return list(candidates)

    @classmethod
    def get_unfreeze_epoch_space(cls):
        """Return the candidate values for ``unfreeze_epoch``."""
        candidates = (3, 5, 7, 10)
        return list(candidates)

@dataclass
class TrainConfig:
    """Training-loop settings: schedule, optimizer values, seed, device and outputs.

    ``device`` defaults to ``"cuda"`` when CUDA is available at the time this
    class is defined, otherwise ``"cpu"``.
    """

    epochs: int = 15
    early_stopping_patience: int = 5

    learning_rate: float = 1e-3
    weight_decay: float = 1e-2
    momentum: float = 0.9

    seed: int = 42
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    out_dir: str = "artifacts"
    save_checkpoints: bool = True
    log_interval: int = 1

    @classmethod
    def get_batch_size_space(cls):
        """Return the candidate batch sizes for hyperparameter search."""
        candidates = (8, 16, 32, 64)
        return list(candidates)

    @classmethod
    def get_epochs_space(cls):
        """Return the candidate epoch counts for hyperparameter search."""
        candidates = (10, 15, 20, 25)
        return list(candidates)

@dataclass
class HyperparameterSearchConfig:
    """Settings for a hyperparameter search run.

    The list-valued fields use ``default_factory`` so every instance gets its
    own list rather than sharing one mutable default.
    """

    search_type: str = "grid"  # search strategy name; how it is interpreted is up to the consumer
    n_trials: int = 8  # presumably an upper bound on trials — TODO confirm against the search code

    # Candidate values explored by the search
    learning_rates: List[float] = field(default_factory=lambda: [0.001, 0.0001])
    # batch_sizes: List[int] = field(default_factory=lambda: [8, 16, 32])
    # weight_decays: List[float] = field(default_factory=lambda: [0.1, 0.01, 0.001, 0.0001])
    epochs_list: List[int] = field(default_factory=lambda: [10, 15])

    tuning_epochs: int = 10  # presumably epochs per tuning trial — TODO confirm
    patience: int = 3  # presumably early-stopping patience during tuning — TODO confirm

@dataclass
class ExperimentConfig:
    """Bundle of the data, model and training configs that define one experiment."""

    data: DataConfig
    model: ModelConfig
    train: TrainConfig

    def to_dict(self):
        """Return the three sub-configs as a dict of plain dicts.

        Each inner dict is a copy of the sub-config's attributes, so editing
        the returned structure (e.g. before dumping it to JSON) cannot
        silently modify this experiment's configuration.
        """
        return {
            'data': dict(vars(self.data)),
            'model': dict(vars(self.model)),
            'train': dict(vars(self.train))
        }

    @classmethod
    def from_dict(cls, config_dict):
        """Build an ExperimentConfig from a dict shaped like ``to_dict()`` output.

        Args:
            config_dict: mapping with ``'data'``, ``'model'`` and ``'train'``
                keys, each holding keyword arguments for the matching config
                class.

        Raises:
            KeyError: if one of the three sections is missing.
            TypeError: if a section contains a key the config class does not accept.
        """
        return cls(
            data=DataConfig(**config_dict['data']),
            model=ModelConfig(**config_dict['model']),
            train=TrainConfig(**config_dict['train'])
        )

# Preset experiment: ResNet-50 with the backbone frozen until epoch 5.
RESNET50_EXPERIMENT = ExperimentConfig(
    data=DataConfig(batch_size=16, img_size=224),
    model=ModelConfig(model_name="resnet50", freeze_backbone=True, unfreeze_epoch=5),
    train=TrainConfig(epochs=15, learning_rate=0.001, seed=42)
)

# Preset experiment: EfficientNet-B0 with the backbone frozen until epoch 3.
# NOTE(review): the analysis and prediction scripts in this notebook look for
# 'mobilenetv3_large_100' checkpoints — confirm this preset is still the
# intended second model.
EFFICIENTNET_EXPERIMENT = ExperimentConfig(
    data=DataConfig(batch_size=32, img_size=224),
    model=ModelConfig(model_name="efficientnet_b0", freeze_backbone=True, unfreeze_epoch=3),
    train=TrainConfig(epochs=15, learning_rate=0.001, seed=42)
)

# Lookup table from model name to its preset experiment.
EXPERIMENT_CONFIGS = {
    'resnet50': RESNET50_EXPERIMENT,
    'efficientnet_b0': EFFICIENTNET_EXPERIMENT,
}

analysis.py charts

In [None]:
#!/usr/bin/env python3
import torch
import json
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from pathlib import Path
import os

def load_model_results():
    """Collect accuracy, epoch and size information from the saved checkpoints.

    For each known model the checkpoint ``artifacts/models/best_<name>.pth``
    is loaded on the CPU. Loading is attempted with ``weights_only=True``
    first; only if that fails for a reason other than a missing file does it
    fall back to ``weights_only=False`` (which unpickles arbitrary objects, so
    the file must come from a trusted source).

    A model whose checkpoint cannot be loaded is reported on stdout and
    skipped.

    Returns:
        pandas.DataFrame with one row per loaded model and the columns
        ``Model``, ``Family``, ``Test Accuracy``, ``Epochs Trained`` and
        ``Parameters``. Empty when no checkpoint could be loaded.
    """
    print("UPLOADING MODEL RESULTS")

    models_info = []
    model_names = ['resnet50', 'mobilenetv3_large_100']

    for model_name in model_names:
        try:
            checkpoint_path = f"artifacts/models/best_{model_name}.pth"

            # PyTorch 2.6 defaults to weights_only=True; try the safe mode first.
            try:
                checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
            except FileNotFoundError:
                # A missing file cannot be fixed by an unsafe load; let the
                # outer handler report it without a misleading warning.
                raise
            except Exception:
                # Fall back to full unpickling (less safe).
                print(f"Warning: Using weights_only=False for {model_name}. Load from trusted source only.")
                checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

            # Pull the metrics out of the checkpoint, trying alternative keys
            val_accuracy = checkpoint.get('val_accuracy', 0)
            if not val_accuracy:
                val_accuracy = checkpoint.get('best_acc', checkpoint.get('accuracy', 0))

            epochs_trained = checkpoint.get('epoch', checkpoint.get('epochs', 0))

            # Count elements of every tensor stored in the state dict
            # NOTE(review): a state dict usually also holds buffers, so this
            # may exceed the number of trainable parameters — confirm if the
            # exact figure matters.
            model_state_dict = checkpoint.get('model_state_dict', checkpoint.get('state_dict', {}))
            if model_state_dict:
                parameters = sum(p.numel() for p in model_state_dict.values())
            else:
                parameters = 0

            models_info.append({
                'Model': model_name,
                'Family': 'ResNet' if 'resnet' in model_name else 'MobileNet',
                'Test Accuracy': float(val_accuracy),
                'Epochs Trained': int(epochs_trained),
                'Parameters': int(parameters)
            })
            print(f"{model_name}: Accuracy = {val_accuracy:.4f}, Parameters = {parameters:,}")

        except Exception as e:
            print(f"{model_name}: {e}")

    return pd.DataFrame(models_info)

def load_results_from_csv():
    """Fallback loader: build the results table from tuning-result CSV files.

    For each known model, the row with the highest ``val_acc`` in its CSV file
    is taken as the best result. Missing or empty files are skipped silently;
    any error while processing a file is printed and that model is skipped.

    Returns:
        pandas.DataFrame with the columns ``Model``, ``Family``,
        ``Test Accuracy``, ``Epochs Trained`` and ``Parameters``.
    """
    print("\nTRYING TO LOAD FROM CSV FILES...")

    csv_files = {
        'resnet50': 'resnet50_tuning_results.csv',
        'mobilenetv3_large_100': 'mobilenetv3_large_100_tuning_results.csv'
    }

    collected = []
    for model_name, csv_file in csv_files.items():
        try:
            if not os.path.exists(csv_file):
                continue

            table = pd.read_csv(csv_file)
            if table.empty:
                continue

            top_row = table.loc[table['val_acc'].idxmax()]
            family = 'ResNet' if 'resnet' in model_name else 'MobileNet'
            accuracy = float(top_row['val_acc'])
            # Prefer an 'epochs' column, then 'best_epoch', else 0
            epochs = int(top_row.get('epochs', top_row.get('best_epoch', 0)))
            n_params = int(estimate_parameters(model_name))

            collected.append({
                'Model': model_name,
                'Family': family,
                'Test Accuracy': accuracy,
                'Epochs Trained': epochs,
                'Parameters': n_params
            })
            print(f"{model_name} from CSV: Accuracy = {top_row['val_acc']:.4f}")
        except Exception as e:
            print(f"Error loading {csv_file}: {e}")

    return pd.DataFrame(collected)

def estimate_parameters(model_name):
    """Return a hard-coded parameter count for a known architecture, else 0."""
    known_sizes = dict(
        resnet50=25_557_032,
        mobilenetv3_large_100=5_483_032,
    )
    try:
        return known_sizes[model_name]
    except KeyError:
        return 0

def plot_comparison(df):
    """Draw a three-panel model comparison and save it as a PNG.

    Panels: accuracy per model, parameter count per model (in millions) and
    mean accuracy per model family. A panel without usable data shows a
    placeholder message instead. The figure is written to
    ``artifacts/plots/model_comparison.png`` and then displayed.

    Args:
        df: results table with ``Model``, ``Family``, ``Test Accuracy`` and
            ``Parameters`` columns.
    """
    print("\nCREATING COMPARISON CHARTS")

    # Make sure the output directory for the figure exists
    os.makedirs('artifacts/plots', exist_ok=True)

    def _label_bars(ax, rects, y_pad, template):
        # Write each bar's height just above the bar.
        for rect in rects:
            top = rect.get_height()
            ax.text(rect.get_x() + rect.get_width()/2., top + y_pad,
                    template.format(top), ha='center', va='bottom', fontweight='bold')

    def _placeholder(ax, message, title):
        # Centered notice for a panel that has nothing to plot.
        ax.text(0.5, 0.5, message,
                ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')

    plt.style.use('default')
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    acc_ax, param_ax, family_ax = axes[0], axes[1], axes[2]

    # Panel 1: accuracy per model
    acc_bars = acc_ax.bar(df['Model'], df['Test Accuracy'],
                          color=['#1f77b4', '#ff7f0e'], alpha=0.7)
    acc_ax.set_title('Model Accuracy Comparison', fontsize=14, fontweight='bold')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.set_ylim(0, 1)
    acc_ax.grid(axis='y', alpha=0.3)
    _label_bars(acc_ax, acc_bars, 0.01, '{:.3f}')

    # Panel 2: parameter counts
    if df['Parameters'].sum() > 0:
        param_bars = param_ax.bar(df['Model'], df['Parameters'] / 1e6,
                                  color=['#2ca02c', '#d62728'], alpha=0.7)
        param_ax.set_title('Number of Parameters (Millions)', fontsize=14, fontweight='bold')
        param_ax.set_ylabel('Million Parameters')
        param_ax.grid(axis='y', alpha=0.3)
        _label_bars(param_ax, param_bars, 0.1, '{:.1f}M')
    else:
        _placeholder(param_ax, 'Parameter data\nnot available', 'Number of Parameters')

    # Panel 3: mean accuracy per family
    family_acc = df.groupby('Family')['Test Accuracy'].mean()
    if len(family_acc) > 1:
        family_bars = family_ax.bar(family_acc.index, family_acc.values,
                                    color=['#9467bd', '#8c564b'], alpha=0.7)
        family_ax.set_title('Accuracy by Model Family', fontsize=14, fontweight='bold')
        family_ax.set_ylabel('Average Accuracy')
        family_ax.set_ylim(0, 1)
        family_ax.grid(axis='y', alpha=0.3)
        _label_bars(family_ax, family_bars, 0.01, '{:.3f}')
    else:
        _placeholder(family_ax, 'Only one model family\navailable', 'Accuracy by Model Family')

    plt.tight_layout()
    plt.savefig('artifacts/plots/model_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("Graphs saved: artifacts/plots/model_comparison.png")

def generate_report(df):
    """Print a textual summary of the comparison table to stdout.

    Covers the most accurate model, the model with the fewest parameters (when
    more than one model is present), per-family aggregates and a simple
    recommendation based on the accuracy gap.

    Args:
        df: results table with ``Model``, ``Family``, ``Test Accuracy``,
            ``Epochs Trained`` and ``Parameters`` columns.
    """
    print("\nANALYTICAL REPORT")
    print("=" * 50)

    if df.empty:
        print("No data available for analysis")
        return

    several_models = len(df) > 1
    winner = df.loc[df['Test Accuracy'].idxmax()]

    print("BASIC METRICS:")
    print(f"Best Model: {winner['Model']}")
    print(f"  - Accuracy: {winner['Test Accuracy']:.4f} ({winner['Test Accuracy']*100:.2f}%)")
    print(f"  - Parameters: {winner['Parameters']:,}")
    print(f"  - Family: {winner['Family']}")
    print(f"  - Epochs Trained: {winner['Epochs Trained']}")

    if several_models:
        smallest = df.loc[df['Parameters'].idxmin()]
        print(f"\nLightest Model: {smallest['Model']}")
        print(f"  - Parameters: {smallest['Parameters']:,}")
        print(f"  - Accuracy: {smallest['Test Accuracy']:.4f}")

    print("\nFAMILY COMPARISON:")
    aggregations = {
        'Test Accuracy': ['mean', 'max', 'count'],
        'Parameters': 'mean'
    }
    print(df.groupby('Family').agg(aggregations).round(4))

    print("\nKEY INSIGHTS:")
    if not several_models:
        print(f"Only one model available: {winner['Model']}")
        print(f"Final accuracy: {winner['Test Accuracy']:.4f}")
        return

    accuracy_diff = float(df['Test Accuracy'].max() - df['Test Accuracy'].min())
    if not df['Parameters'].sum() > 0:
        print(f"Accuracy difference: {accuracy_diff:.4f}")
        print("RECOMMENDATION: Choose model with highest accuracy")
        return

    param_ratio = float(df['Parameters'].max() / df['Parameters'].min())
    print(f"Accuracy difference: {accuracy_diff:.4f}")
    print(f"Parameter ratio: {param_ratio:.1f}x")

    # Below a 5-point accuracy gap the smaller model is recommended
    if accuracy_diff < 0.05:
        print("RECOMMENDATION: Choose the lighter model (accuracy difference is negligible)")
    else:
        print("RECOMMENDATION: Choose the more accurate model")

def _to_native(value):
    """Convert one table cell to a value ``json.dump`` can serialise.

    Missing values become ``None`` and NumPy integer/float scalars become
    built-in ``int``/``float``; anything else is returned unchanged.
    """
    if pd.isna(value):
        return None
    if isinstance(value, np.integer):
        return int(value)
    if isinstance(value, np.floating):
        return float(value)
    return value


def save_results(df):
    """Save the analysis results to ``artifacts/analysis_results.json``.

    For a non-empty table the JSON holds the best model (highest
    ``Test Accuracy``), every row of the comparison and a short summary. For
    an empty table it holds a single ``error`` entry.

    Args:
        df: results table with at least the ``Test Accuracy`` and
            ``Parameters`` columns.
    """
    os.makedirs('artifacts', exist_ok=True)

    if df.empty:
        results = {'error': 'No data available for analysis'}
    else:
        best_model = df.loc[df['Test Accuracy'].idxmax()]

        # pandas/NumPy scalars are not JSON serialisable; convert them first
        best_model_dict = {key: _to_native(value) for key, value in best_model.items()}
        comparison_list = [
            {key: _to_native(value) for key, value in row.items()}
            for _, row in df.iterrows()
        ]

        results = {
            'best_model': best_model_dict,
            'comparison': comparison_list,
            'summary': {
                'accuracy_range': [
                    float(df['Test Accuracy'].min()),
                    float(df['Test Accuracy'].max())
                ],
                'parameter_range': [
                    int(df['Parameters'].min()),
                    int(df['Parameters'].max())
                ],
                'models_tested': int(len(df)),
                'best_accuracy': float(best_model['Test Accuracy'])
            }
        }

    with open('artifacts/analysis_results.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"\nResults saved: artifacts/analysis_results.json")

def main():
    """Run the analysis: load results, plot them, print a report and save JSON.

    Results are read from the model checkpoints first; the CSV files are only
    consulted when that yields nothing. When both sources are empty a short
    troubleshooting hint is printed and nothing else happens.
    """
    print("ANALYSIS OF EXPERIMENTAL RESULTS")
    print("=" * 50)

    # Primary source: the saved checkpoints
    df = load_model_results()

    # Secondary source: the tuning CSV files
    if df.empty:
        print("\nTrying alternative data sources...")
        df = load_results_from_csv()

    if df.empty:
        for hint in (
            "\nNo data available for analysis",
            "Please check:",
            "1. Model files exist in artifacts/models/",
            "2. CSV files with results exist",
        ):
            print(hint)
        return

    plot_comparison(df)   # charts
    generate_report(df)   # console report
    save_results(df)      # JSON summary

    print("\nAnalysis completed successfully!")
    print("Check files in artifacts/ directory")

if __name__ == "__main__":
    main()

test_predictions.py

In [None]:
import torch
import torch.nn as nn
import timm
from PIL import Image
import torchvision.transforms as transforms
import os

def load_model_fixed(model_name, num_classes=3):
    """Create a supported timm model and load its best checkpoint on the CPU.

    The checkpoint ``artifacts/models/best_<model_name>.pth`` is first loaded
    with ``weights_only=True``. Only if that fails for a reason other than a
    missing file does loading fall back to ``weights_only=False``.

    SECURITY: ``weights_only=False`` unpickles arbitrary objects and can run
    code embedded in the file — only use checkpoints from a trusted source.

    Args:
        model_name: ``'mobilenetv3_large_100'`` or ``'resnet50'``.
        num_classes: size of the classification head.

    Returns:
        The model in eval mode, or ``None`` when the name is unknown or
        anything goes wrong while creating/loading (the error is printed).
    """
    supported_models = ('mobilenetv3_large_100', 'resnet50')
    try:
        if model_name not in supported_models:
            print(f"Unknown model: {model_name}")
            return None

        model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)

        checkpoint_path = f"artifacts/models/best_{model_name}.pth"
        try:
            checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
        except FileNotFoundError:
            # An unsafe retry cannot help with a missing file.
            raise
        except Exception:
            print(f"Warning: Using weights_only=False for {model_name}. Load from trusted source only.")
            checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

        # The weights may be wrapped under one of two keys or stored bare
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        elif 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint)

        model.eval()
        print(f"✓ {model_name} loaded successfully")
        return model

    except Exception as e:
        print(f"✗ Model loading error {model_name}: {e}")
        return None

def predict_image(model, image_path, class_names, transform):
    """Classify one image file and return ``(label, confidence)``.

    The image is opened as RGB, passed through ``transform``, given a batch
    dimension and fed to ``model``. The label is the entry of ``class_names``
    at the highest softmax probability.

    Any failure is caught: the function then returns
    ``("Error: <message>", 0.0)`` instead of raising.
    """
    try:
        rgb_image = Image.open(image_path).convert('RGB')
        batch = transform(rgb_image).unsqueeze(0)

        with torch.no_grad():
            logits = model(batch)[0]  # scores of the only sample in the batch
            probs = torch.nn.functional.softmax(logits, dim=0)
            best_index = torch.argmax(probs).item()
            best_prob = probs[best_index].item()

        label = class_names[best_index]
        return label, best_prob
    except Exception as e:
        return f"Error: {e}", 0.0

def main():
    """Load the trained models and print their predictions for sample images.

    Every model that loads successfully is run on three fixed images (one per
    class folder). Missing image files are reported instead of predicted. When
    no model can be loaded the function reports that and returns.
    """
    print("TESTING MODEL PREDICTIONS")

    class_names = ['minivan', 'sedan', 'wagon']

    # Same size as training; normalised with the ImageNet mean/std
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    test_images = [
        'data/raw/minivan/kisspng-2016-nissan-quest-car-2013-nissan-quest-2011-nissa-5ae78661c06196.926277781525122657788.jpg',
        'data/raw/sedan/d2e5f48218cc6572a9388998fa185769.jpg',
        'data/raw/wagon/0d1901c8711742e1cb1769c805ad9e8b.jpg'
    ]

    print("Test Images:")
    for image_file in test_images:
        print(f"  {os.path.basename(image_file)}")
    print(f"Classes: {class_names}")

    models = {}
    for model_name in ('resnet50', 'mobilenetv3_large_100'):
        print(f"\nLoading: {model_name}")
        loaded = load_model_fixed(model_name, num_classes=len(class_names))
        if loaded:
            models[model_name] = loaded

    if not models:
        print("Couldn't load any models")
        return

    banner = '=' * 50
    divider = "-" * 40
    print(f"\n{banner}")
    print("PREDICTION RESULTS:")
    print(banner)

    for model_name, model in models.items():
        print(f"\n{model_name.upper()} PREDICTIONS:")
        print(divider)

        for img_path in test_images:
            if not os.path.exists(img_path):
                print(f"  {img_path}: File not found")
                continue

            predicted_class, confidence = predict_image(model, img_path, class_names, transform)
            print(f"  {os.path.basename(img_path)}:")
            print(f"    Predicted: {predicted_class}")
            print(f"    Confidence: {confidence:.2%}")

        print(divider)

if __name__ == "__main__":
    main()

train.py

In [None]:
#!/usr/bin/env python3
"""
Main training script with robust error handling for model downloads.
Uses ResNet50 as first model and MobileNetV3 as second model.
"""

import os
import random
import json
import logging
import sys
from datetime import datetime
from pathlib import Path
import warnings

# Create the log directory and configure the root logger.
# This runs at import time, before the heavy third-party imports below.
Path("artifacts").mkdir(exist_ok=True)
log_file = "artifacts/training.log"
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.handlers.clear()  # drop any handlers that were installed earlier

# One shared format for the file and the console output
formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")

# mode="w" starts a fresh log file on every run
file_handler = logging.FileHandler(log_file, mode="w", encoding="utf-8")
file_handler.setFormatter(formatter)

# Mirror log records to stdout
stream_handler = logging.StreamHandler(sys.stdout)
stream_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.addHandler(stream_handler)

# Silence all Python warnings for the whole process
warnings.filterwarnings('ignore')
# NOTE(review): presumably meant to bound timm's pretrained-weight downloads —
# confirm that timm actually reads these environment variables.
os.environ['TIMM_DOWNLOAD_TIMEOUT'] = '30'
os.environ['TIMM_DOWNLOAD_RETRY'] = '2'

import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
import argparse       
import pandas as pd    
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import timm
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

from config import DataConfig, HyperparameterSearchConfig, ModelConfig, TrainConfig, ExperimentConfig, EXPERIMENT_CONFIGS


class ModelTrainer:
    """Model trainer with robust error handling."""
    
    def __init__(self, experiment_config: ExperimentConfig):
        stats_path = "artifacts/stats.json"
        if not Path(stats_path).exists():
            raise FileNotFoundError(f"{stats_path} не найден. Сначала запустите compute_stats.py")
    
        with open(stats_path, "r") as f:
            stats = json.load(f)

        self.DATA_MEAN = stats["mean"]
        self.DATA_STD = stats["std"]

        self.config = experiment_config
        self.data_cfg = experiment_config.data
        self.model_cfg = experiment_config.model
        self.train_cfg = experiment_config.train
        self.set_seed()
        
    def set_seed(self):
        """Fix all random generators for reproducibility."""
        random.seed(self.train_cfg.seed)
        np.random.seed(self.train_cfg.seed)
        torch.manual_seed(self.train_cfg.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.train_cfg.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        
    def seed_worker(self, worker_id):
        """Seed worker for DataLoader reproducibility."""
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)
        
    def setup_data_directories(self):
        """Create necessary directories."""
        data_dirs = [
            "data/raw/minivan",
            "data/raw/sedan", 
            "data/raw/wagon",
            "artifacts/models",
            "artifacts/plots"
        ]
        
        for dir_path in data_dirs:
            Path(dir_path).mkdir(parents=True, exist_ok=True)
        
    def check_data_exists(self, data_root):
        if not Path(data_root).exists():
            print(f"Directory {data_root} does not exist")
            logger.info(f"Directory {data_root} does not exist")
            return False
            
        subdirs = [d for d in Path(data_root).iterdir() if d.is_dir()]
        if not subdirs:
            print(f"No class directories in {data_root}!")
            logger.info(f"No class directories in {data_root}!")
            return False
            
        print(f"Data found in: {data_root}")
        logger.info(f"Data found in: {data_root}")
        total_images = 0
        
        for subdir in subdirs:
            image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.JPG', '.JPEG', '.PNG', '.WEBP'}
            images = [f for f in subdir.iterdir() if f.suffix.lower() in image_extensions and f.is_file()]
            
            print(f" {subdir.name}: {len(images)} images")
            logger.info(f" {subdir.name}: {len(images)} images")
            total_images += len(images)
            
        print(f"Total images: {total_images}")
        logger.info(f"Total images: {total_images}")
        
        if total_images < 30:
            print("Less than 30 images total - may affect model performance")
            logger.info("Less than 30 images total - may affect model performance")
            
        return total_images > 0
        
    def get_transforms(self):
        """Build the training and evaluation transform pipelines.

        The evaluation pipeline resizes to ``img_size`` x ``img_size``,
        converts to a tensor and normalises with the dataset statistics loaded
        in ``__init__``. With ``data_cfg.augmentation`` enabled, the training
        pipeline instead applies a random resized crop, a horizontal flip and
        colour jitter before the same tensor conversion and normalisation.

        The brightness/contrast/saturation jitter now comes from
        ``data_cfg.color_jitter`` (default 0.15, the value that used to be
        hard-coded), so changing the config actually takes effect. The hue
        jitter stays at 0.05.

        NOTE(review): ``data_cfg.random_rotate`` is not applied here — confirm
        whether a rotation was intended.

        Returns:
            Tuple ``(train_transform, val_transform)``.
        """
        img_size = self.data_cfg.img_size

        def eval_pipeline():
            # Deterministic preprocessing for validation/test and for
            # un-augmented training.
            return transforms.Compose([
                transforms.Resize((img_size, img_size)),
                transforms.ToTensor(),
                transforms.Normalize(self.DATA_MEAN, self.DATA_STD)
            ])

        if self.data_cfg.augmentation:
            # Fall back to the historical value for configs lacking the field
            jitter = getattr(self.data_cfg, "color_jitter", 0.15)
            train_t = transforms.Compose([
                transforms.RandomResizedCrop(img_size),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(jitter, jitter, jitter, 0.05),
                transforms.ToTensor(),
                transforms.Normalize(self.DATA_MEAN, self.DATA_STD)
            ])
        else:
            train_t = eval_pipeline()

        val_t = eval_pipeline()

        return train_t, val_t

    def create_data_loaders(self):
        """Create the train/val/test DataLoaders from a seeded, stratified split.

        The dataset under ``data_cfg.data_root`` is split by index in two
        stratified steps, both seeded with ``train_cfg.seed``. The training
        subset uses the training transform; validation and test use the
        evaluation transform. Also stores the class names in ``self.classes``.

        Returns:
            Tuple ``(train_loader, val_loader, test_loader)``.
        """
        train_t, val_t = self.get_transforms()

        # Transform-less dataset, used only for the sample list and class names
        full_dataset = ImageFolder(self.data_cfg.data_root)
        indices = list(range(len(full_dataset)))
        labels = [s[1] for s in full_dataset.samples]

        # Step 1: carve off the test set
        train_idx, test_idx = train_test_split(
            indices,
            test_size=self.data_cfg.test_ratio,
            stratify=labels,
            random_state=self.train_cfg.seed
        )
        # Step 2: carve the validation set out of the remainder. The ratio is
        # rescaled so that val_ratio is a fraction of the whole dataset.
        train_idx, val_idx = train_test_split(
            train_idx,
            test_size=self.data_cfg.val_ratio/(1-self.data_cfg.test_ratio),
            stratify=[labels[i] for i in train_idx],
            random_state=self.train_cfg.seed
        )

        # Separate dataset objects over the same folder so each split can
        # carry its own transform
        train_ds = ImageFolder(self.data_cfg.data_root, transform=train_t)
        val_ds = ImageFolder(self.data_cfg.data_root, transform=val_t)
        test_ds = ImageFolder(self.data_cfg.data_root, transform=val_t)

        train_subset = Subset(train_ds, train_idx)
        val_subset = Subset(val_ds, val_idx)
        test_subset = Subset(test_ds, test_idx)

        # Seeded generator that drives the training loader's shuffling
        generator = torch.Generator()
        generator.manual_seed(self.train_cfg.seed)

        # Only the training loader shuffles, so only it gets the seeded
        # generator and the worker seeding hook
        train_loader = DataLoader(
            train_subset,
            batch_size=self.data_cfg.batch_size,
            shuffle=True,
            num_workers=self.data_cfg.num_workers,
            worker_init_fn=self.seed_worker,
            generator=generator,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_subset,
            batch_size=self.data_cfg.batch_size,
            shuffle=False,
            num_workers=self.data_cfg.num_workers,
            pin_memory=True
        )

        test_loader = DataLoader(
            test_subset,
            batch_size=self.data_cfg.batch_size,
            shuffle=False,
            num_workers=self.data_cfg.num_workers,
            pin_memory=True
        )

        # Class names are kept on the instance; validate() reads them
        self.classes = full_dataset.classes
        print(f"Classes: {self.classes}")
        logger.info(f"Classes: {self.classes}")
        print(f"Dataset split: train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")
        logger.info(f"Dataset split: train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")

        return train_loader, val_loader, test_loader

    def create_model(self):
        """Create the configured timm model, optionally freezing the backbone.

        If creating the model with the configured ``pretrained`` flag raises,
        the model is created again with ``pretrained=False`` (random
        initialisation). With ``model_cfg.freeze_backbone`` set, only
        parameters whose name contains ``'head'``, ``'fc'`` or
        ``'classifier'`` stay trainable.

        Returns:
            The created model.
        """
        def announce(message):
            # Send the same text to stdout and to the logger.
            print(message)
            logger.info(message)

        model_name = self.model_cfg.model_name
        announce(f"Creating model: {model_name}")
        announce(f"Pretrained: {self.model_cfg.pretrained}")

        try:
            model = timm.create_model(
                model_name,
                pretrained=self.model_cfg.pretrained,
                num_classes=self.model_cfg.num_classes
            )
            announce(f"Successfully created {model_name}")

        except Exception as e:
            announce(f"Failed to download {model_name}: {e}")
            announce("Falling back to randomly initialized model...")

            model = timm.create_model(
                model_name,
                pretrained=False,
                num_classes=self.model_cfg.num_classes
            )
            announce(f"Created {model_name} with random initialization")

        if self.model_cfg.freeze_backbone:
            head_markers = ('head', 'fc', 'classifier')
            frozen_count = trainable_count = 0

            for param_name, param in model.named_parameters():
                # Substring match on the parameter name decides head vs. backbone
                is_head = any(marker in param_name for marker in head_markers)
                param.requires_grad = is_head
                if is_head:
                    trainable_count += 1
                else:
                    frozen_count += 1

            announce(f"Freezing: {frozen_count} frozen, {trainable_count} trainable")
            announce(f"Unfreeze at epoch: {self.model_cfg.unfreeze_epoch}")

        return model

    def unfreeze_model(self, model):
        """Unfreeze all model parameters."""
        print("Unfreezing all parameters")
        logger.info("Unfreezing all parameters")
        for param in model.parameters():
            param.requires_grad = True

    def train_epoch(self, model, loader, criterion, optimizer, device):
        """Run one optimisation pass over ``loader``.

        Returns:
            Tuple ``(average_loss, accuracy)``: the loss averaged over
            ``len(loader.dataset)`` samples and the accuracy of the
            predictions made during the pass.
        """
        model.train()

        loss_total = 0.0
        predicted = []
        expected = []

        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            logits = model(images)
            batch_loss = criterion(logits, labels)
            batch_loss.backward()
            optimizer.step()

            # Weight the batch loss by the batch size before accumulating
            loss_total += batch_loss.item() * images.size(0)
            predicted.extend(logits.argmax(dim=1).cpu().numpy())
            expected.extend(labels.cpu().numpy())

        mean_loss = loss_total / len(loader.dataset)
        accuracy = accuracy_score(expected, predicted)
        return mean_loss, accuracy

    def validate(self, model, loader, criterion, device):
        """Evaluate ``model`` on ``loader`` without computing gradients.

        Returns:
            Tuple ``(average_loss, accuracy, confusion_matrix, report)``. The
            report is the text classification report labelled with
            ``self.classes``.
        """
        model.eval()

        loss_total = 0.0
        predicted = []
        expected = []

        with torch.no_grad():
            for images, labels in loader:
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                logits = model(images)
                batch_loss = criterion(logits, labels)

                # Weight the batch loss by the batch size before accumulating
                loss_total += batch_loss.item() * images.size(0)
                predicted.extend(logits.argmax(dim=1).cpu().numpy())
                expected.extend(labels.cpu().numpy())

        mean_loss = loss_total / len(loader.dataset)
        accuracy = accuracy_score(expected, predicted)
        matrix = confusion_matrix(expected, predicted)
        text_report = classification_report(
            expected,
            predicted,
            target_names=self.classes,
            zero_division=0,
        )
        return mean_loss, accuracy, matrix, text_report

    def train(self):
        """Train the configured model, evaluate it on the test split and plot the run.

        Builds the data loaders and the model, optimises with AdamW under a
        cosine learning-rate schedule, and checkpoints through ``save_model``
        whenever validation accuracy improves. When ``freeze_backbone`` is set,
        all parameters are unfrozen at ``unfreeze_epoch`` and the optimizer and
        scheduler are rebuilt at a tenth of the base learning rate.

        Returns:
            tuple: ``(history, final_metrics)``. ``history`` maps
            ``train_loss``, ``val_loss``, ``train_acc``, ``val_acc`` and
            ``learning_rates`` to per-epoch lists; ``final_metrics`` is the
            dict returned by ``final_evaluation``.
        """
        print(f"\nTraining {self.model_cfg.model_name}")
        logger.info(f"\nTraining {self.model_cfg.model_name}")

        train_loader, val_loader, test_loader = self.create_data_loaders()

        model = self.create_model()
        model = model.to(self.train_cfg.device)

        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Parameters: {trainable_params:,} trainable / {total_params:,} total")
        logger.info(f"Parameters: {trainable_params:,} trainable / {total_params:,} total")

        criterion = nn.CrossEntropyLoss()
        # Only hand the currently trainable parameters to the optimizer; the
        # optimizer is rebuilt below once the backbone is unfrozen.
        optimizer = AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=self.model_cfg.learning_rate,
            weight_decay=self.model_cfg.weight_decay
        )
        scheduler = CosineAnnealingLR(optimizer, T_max=self.train_cfg.epochs)

        history = {
            'train_loss': [], 'val_loss': [],
            'train_acc': [], 'val_acc': [],
            'learning_rates': []
        }
        # Start below any reachable accuracy so the first epoch always writes
        # a checkpoint. Starting at 0.0 left no checkpoint from this run when
        # validation accuracy never rose above zero.
        best_val_acc = float("-inf")

        print(f"\nTraining for {self.train_cfg.epochs} epochs on {self.train_cfg.device}")
        logger.info(f"\nTraining for {self.train_cfg.epochs} epochs on {self.train_cfg.device}")

        for epoch in range(self.train_cfg.epochs):
            if epoch == self.model_cfg.unfreeze_epoch and self.model_cfg.freeze_backbone:
                print(f"\nEpoch {epoch}: unfreezing backbone")
                logger.info(f"\nEpoch {epoch}: unfreezing backbone")
                self.unfreeze_model(model)
                # Fresh optimizer over all parameters at a reduced rate, with
                # the cosine schedule restarted over the remaining epochs.
                optimizer = AdamW(
                    model.parameters(),
                    lr=self.model_cfg.learning_rate/10,
                    weight_decay=self.model_cfg.weight_decay
                )
                scheduler = CosineAnnealingLR(optimizer, T_max=self.train_cfg.epochs - epoch)

            # Capture the learning rate that this epoch trains with. Reading
            # it after scheduler.step() would record the next epoch's rate.
            epoch_lr = optimizer.param_groups[0]['lr']

            train_loss, train_acc = self.train_epoch(
                model, train_loader, criterion, optimizer, self.train_cfg.device
            )

            val_loss, val_acc, cm, report = self.validate(
                model, val_loader, criterion, self.train_cfg.device
            )

            scheduler.step()

            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['train_acc'].append(train_acc)
            history['val_acc'].append(val_acc)
            history['learning_rates'].append(epoch_lr)

            should_log = (epoch + 1) % self.train_cfg.log_interval == 0
            if should_log:
                print(f"Epoch {epoch+1}/{self.train_cfg.epochs}: "
                      f"Train Loss: {train_loss:.4f}, accuracy: {train_acc:.4f} | "
                      f"Val Loss: {val_loss:.4f}, accuracy: {val_acc:.4f}")
                logger.info(f"Epoch {epoch+1}/{self.train_cfg.epochs}: "
                            f"Train Loss: {train_loss:.4f}, accuracy: {train_acc:.4f} | "
                            f"Val Loss: {val_loss:.4f}, accuracy: {val_acc:.4f}")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                self.save_model(model, epoch, val_acc)
                if should_log:
                    print(f"New best model Val accuracy: {val_acc:.4f}")
                    logger.info(f"New best model Val accuracy: {val_acc:.4f}")

        # NOTE: the test evaluation uses the in-memory weights from the last
        # epoch, and ``cm`` is the last epoch's validation confusion matrix;
        # neither is reloaded from the best checkpoint.
        final_metrics = self.final_evaluation(model, test_loader, criterion)
        self.plot_results(history, cm)

        return history, final_metrics

    def save_model(self, model, epoch, accuracy):
        """Write the current weights as the best checkpoint for this model.

        Creates ``out_dir`` with its ``models`` and ``plots`` sub-directories,
        saves the checkpoint to ``<out_dir>/models/best_<model_name>.pth`` and
        rewrites ``<out_dir>/classes.txt`` with one class name per line.

        Args:
            model: module whose ``state_dict()`` is stored.
            epoch: epoch index recorded in the checkpoint.
            accuracy: validation accuracy recorded in the checkpoint.
        """
        out_dir = self.train_cfg.out_dir
        Path(out_dir).mkdir(parents=True, exist_ok=True)
        Path(f"{out_dir}/models").mkdir(exist_ok=True)
        Path(f"{out_dir}/plots").mkdir(exist_ok=True)

        model_path = f"{out_dir}/models/best_{self.model_cfg.model_name}.pth"
        torch.save({
            'model_state_dict': model.state_dict(),
            'classes': self.classes,
            'config': self.config.to_dict(),
            'seed': self.train_cfg.seed,
            # Store a plain Python float: a NumPy scalar would be pickled into
            # the file and can be rejected by torch.load when it runs in
            # weights-only mode.
            'val_accuracy': float(accuracy),
            'epoch': epoch
        }, model_path)

        with open(f"{out_dir}/classes.txt", 'w', encoding='utf-8') as f:
            f.write('\n'.join(self.classes))

        print(f"Model saved: {model_path}")
        logger.info(f"Model saved: {model_path}")

    def final_evaluation(self, model, test_loader, criterion):
        """Score ``model`` on the test loader and report the outcome.

        Runs ``self.validate`` on ``test_loader`` using the configured device,
        then writes the loss, accuracy and classification report to stdout
        and to the module logger.

        Returns:
            dict: ``test_loss``, ``test_acc``, ``confusion_matrix`` and
            ``report`` as produced by ``self.validate``.
        """
        loss_value, accuracy, matrix, text_report = self.validate(
            model, test_loader, criterion, self.train_cfg.device
        )

        # Each entry goes to stdout first and then to the logger.
        report_lines = (
            f"\nFinal test result:",
            f"   Test loss: {loss_value:.4f}",
            f"   Test accuracy: {accuracy:.4f}",
            "\nClassification report:",
            text_report,
        )
        for line in report_lines:
            print(line)
            logger.info(line)

        outcome = {}
        outcome['test_loss'] = loss_value
        outcome['test_acc'] = accuracy
        outcome['confusion_matrix'] = matrix
        outcome['report'] = text_report
        return outcome

    def plot_results(self, history, cm):
        """Plot loss curves, accuracy curves and a confusion matrix to one PNG.

        The epoch with the highest validation accuracy is marked on both curve
        panels and summarised in the figure title. The image is written to
        ``<out_dir>/plots/<model_name>_training.png``.

        Args:
            history: dict with per-epoch lists under ``train_loss``,
                ``val_loss``, ``train_acc`` and ``val_acc``; an optional
                ``learning_rates`` list adds the learning rate to the labels.
            cm: confusion matrix of integer counts, drawn with
                ``self.classes`` as tick labels on both axes.
        """
        Path(f"{self.train_cfg.out_dir}/plots").mkdir(parents=True, exist_ok=True)

        best_epoch = int(np.argmax(history['val_acc']))
        best_val_acc = history['val_acc'][best_epoch]
        # The learning rate is optional; treat a missing or too-short list as
        # "unknown" so no label tries to format None as a float.
        lr_trace = history.get('learning_rates')
        best_lr = lr_trace[best_epoch] if lr_trace and best_epoch < len(lr_trace) else None

        fig, axes = plt.subplots(1, 3, figsize=(20, 6))

        axes[0].plot(history['train_loss'], label='Train loss', linewidth=2, color='tab:blue')
        axes[0].plot(history['val_loss'], label='Validation loss', linewidth=2, color='tab:orange')
        axes[0].axvline(best_epoch, color='gray', linestyle='--', alpha=0.7)
        axes[0].set_title('Graph of the loss function (loss)', fontsize=12)
        axes[0].set_xlabel('Epoch', fontsize=10)
        axes[0].set_ylabel('Value Loss', fontsize=10)
        axes[0].legend(loc='best')
        axes[0].grid(True, alpha=0.3)

        axes[1].plot(history['train_acc'], label='Train accuracy', linewidth=2, color='tab:green')
        axes[1].plot(history['val_acc'], label='Validation accuracy', linewidth=2, color='tab:red')
        axes[1].axvline(best_epoch, color='gray', linestyle='--', alpha=0.7)
        # Build the annotation with the same None guard as the figure title;
        # the unguarded ":.6f" raised TypeError when no learning rate existed.
        best_note = f"Best epoch: {best_epoch + 1}\nVal acc = {best_val_acc:.4f}"
        if best_lr is not None:
            best_note += f"\nLR = {best_lr:.6f}"
        axes[1].annotate(
            best_note,
            xy=(best_epoch, best_val_acc),
            xytext=(best_epoch + 0.5, best_val_acc - 0.05),
            arrowprops=dict(arrowstyle='->', color='gray'),
            fontsize=9,
            bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8)
        )
        axes[1].set_title('Graph accuracy', fontsize=12)
        axes[1].set_xlabel('Epoch', fontsize=10)
        axes[1].set_ylabel('The proportion of correct answers', fontsize=10)
        axes[1].legend(loc='best')
        axes[1].grid(True, alpha=0.3)

        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[2],
                    xticklabels=self.classes, yticklabels=self.classes)
        axes[2].set_xlabel('Predicted class', fontsize=10)
        axes[2].set_ylabel('The true class', fontsize=10)
        axes[2].set_title('Confusion Matrix', fontsize=12)

        summary_text = (
            f"Best epoch: {best_epoch + 1} | "
            f"Val acc = {best_val_acc:.4f}"
            + (f" | Learning rate = {best_lr:.6f}" if best_lr is not None else "")
        )
        fig.suptitle(summary_text, fontsize=13, fontweight='bold', y=1.02)

        # Operate on this figure explicitly instead of pyplot's "current"
        # figure, so other open figures cannot be saved or closed by mistake.
        fig.tight_layout()
        plot_path = f"{self.train_cfg.out_dir}/plots/{self.model_cfg.model_name}_training.png"
        fig.savefig(plot_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Plots (loss/acc/confusion) saved: {plot_path}")
        logger.info(f"Plots (loss/acc/confusion) saved: {plot_path}")

    @staticmethod
    def export_onnx(model_name, num_classes, img_size=224):
        """Export the saved best checkpoint of ``model_name`` to ONNX.

        Loads ``artifacts/models/best_<model_name>.pth`` on CPU, rebuilds the
        architecture with ``timm.create_model`` and writes
        ``artifacts/best_<model_name>.onnx`` with a dynamic batch axis
        (opset 12).

        Args:
            model_name: timm architecture name, also used in both file names.
            num_classes: size of the classification head to rebuild.
            img_size: height and width of the square dummy input.

        Returns:
            bool: ``True`` when the ONNX file was written; ``False`` when the
            checkpoint is missing or any step of the export raised.
        """
        device = torch.device("cpu")
        model_path = f"artifacts/models/best_{model_name}.pth"
        if not os.path.exists(model_path):
            print(f"Model {model_path} not found for ONNX export")
            # A missing checkpoint is a problem, not routine information.
            logger.warning(f"Model {model_path} not found for ONNX export")
            return False

        try:
            checkpoint = torch.load(model_path, map_location=device)
            model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval().to(device)

            dummy_input = torch.randn(1, 3, img_size, img_size, device=device)
            onnx_path = f"artifacts/best_{model_name}.onnx"

            torch.onnx.export(
                model, dummy_input, onnx_path,
                input_names=['input'],
                output_names=['output'],
                dynamic_axes={
                    'input': {0: 'batch_size'},
                    'output': {0: 'batch_size'}
                },
                opset_version=12
            )
        except Exception as e:
            # Export is best-effort: report and return False, but keep the
            # traceback in the log so the cause can be diagnosed.
            print(f"ONNX export failed for {model_name}: {e}")
            logger.exception(f"ONNX export failed for {model_name}: {e}")
            return False

        print(f"ONNX model exported: {onnx_path}")
        logger.info(f"ONNX model exported: {onnx_path}")
        return True


def get_alternative_model_config():
    """Build the experiment configuration for the MobileNetV3 run.

    Returns:
        ExperimentConfig: batch size 32 at 224px, a pretrained
        ``mobilenetv3_large_100`` with three classes whose backbone is frozen
        until epoch 3, trained for 15 epochs with seed 42.
    """
    data_section = DataConfig(batch_size=32, img_size=224)

    model_section = ModelConfig(
        model_name="mobilenetv3_large_100",
        num_classes=3,
        pretrained=True,
        freeze_backbone=True,
        unfreeze_epoch=3,
        learning_rate=0.001,
        weight_decay=0.01,
    )

    train_section = TrainConfig(epochs=15, seed=42)

    return ExperimentConfig(
        data=data_section,
        model=model_section,
        train=train_section,
    )


def _print_and_log(message):
    """Echo ``message`` on stdout and record it with the module logger."""
    print(message)
    logger.info(message)


def _seed_everything(seed):
    """Seed the Python, NumPy and torch RNGs and make cuDNN deterministic."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _save_tuning_confusion_matrix(cm, classes, model_name, lr, ep):
    """Save one tuning run's confusion matrix as a PNG; failures are only reported."""
    cm_path = f"artifacts/plots/{model_name}_lr{lr}_ep{ep}_cm.png"
    try:
        # The trainer may write its own plots elsewhere, so make sure this
        # fixed directory exists before saving into it.
        Path(cm_path).parent.mkdir(parents=True, exist_ok=True)
        fig = plt.figure(figsize=(5, 5))
        try:
            sns.heatmap(
                cm,
                annot=True,
                fmt='d',
                cmap='Blues',
                xticklabels=classes,
                yticklabels=classes
            )
            plt.title(f"{model_name} | lr={lr}, epochs={ep}")
            plt.xlabel("Predicted")
            plt.ylabel("True")
            plt.tight_layout()
            plt.savefig(cm_path, dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)
    except Exception as e:
        print(f"Could not save confusion matrix {cm_path}: {e}")
        logger.exception(f"Could not save confusion matrix {cm_path}: {e}")
        return
    _print_and_log(f"Confusion matrix saved: {cm_path}")


def _tune_model(model_name, exp_cfg, hpcfg):
    """Grid-search learning rate and epochs for one model, then retrain with the best pair."""
    _print_and_log(f"\nHyperparameter tunning for {model_name.upper()}")
    results = []
    succeeded = 0

    for lr in hpcfg.learning_rates:
        for ep in hpcfg.epochs_list:
            _print_and_log(f"\nTraining {model_name} with lr={lr}, epochs={ep}")

            exp_cfg.model.learning_rate = lr
            exp_cfg.train.epochs = ep

            trainer = ModelTrainer(exp_cfg)
            # Only training itself decides whether the run counts as failed;
            # a plotting problem must not zero the score of a good run.
            try:
                history, metrics = trainer.train()
            except Exception as e:
                print(f"Training failed for {model_name} (lr={lr}, ep={ep}): {e}")
                logger.exception(f"Training failed for {model_name} (lr={lr}, ep={ep}): {e}")
                val_acc = 0.0
                test_acc = 0.0
            else:
                succeeded += 1
                # Select on validation accuracy. Picking hyperparameters by
                # test accuracy would leak the test set into model selection.
                val_history = history['val_acc']
                val_acc = float(max(val_history)) if val_history else 0.0
                test_acc = float(metrics['test_acc'])
                _save_tuning_confusion_matrix(
                    metrics['confusion_matrix'], trainer.classes, model_name, lr, ep
                )

            results.append({
                'model': model_name,
                'lr': lr,
                'epochs': ep,
                'val_acc': val_acc,
                'test_acc': test_acc
            })

    if not results:
        _print_and_log(f"No hyperparameter combinations configured for {model_name}")
        return

    df = pd.DataFrame(results)
    Path("artifacts").mkdir(parents=True, exist_ok=True)
    csv_path = f"artifacts/{model_name}_tuning_results.csv"
    df.to_csv(csv_path, index=False)
    _print_and_log(f"Saved results: {csv_path}")

    if succeeded == 0:
        # Every score is the 0.0 placeholder, so there is nothing to choose.
        _print_and_log(f"All tuning runs failed for {model_name}; skipping retraining")
        return

    best_row = df.loc[df['val_acc'].idxmax()]
    # Convert back to built-in types before writing into the config, so no
    # NumPy scalars end up in the configuration saved with the checkpoint.
    best_lr = float(best_row['lr'])
    best_ep = int(best_row['epochs'])
    _print_and_log(f"\nBest for {model_name}: lr={best_lr}, epochs={best_ep}, acc={best_row['val_acc']:.4f}")

    exp_cfg.model.learning_rate = best_lr
    exp_cfg.train.epochs = best_ep
    _print_and_log(f"Retraining {model_name} with best params...")
    trainer = ModelTrainer(exp_cfg)
    trainer.train()


def _train_and_export_all(models_to_try):
    """Train each model once, export it to ONNX and print a summary."""
    results = {}
    successful_models = []

    for model_name, exp_config in models_to_try:
        _print_and_log(f"Processing model: {model_name.upper()}")

        try:
            trainer = ModelTrainer(exp_config)
            history, metrics = trainer.train()

            results[model_name] = {
                'history': history,
                'metrics': metrics
            }
            successful_models.append(model_name)

            export_success = ModelTrainer.export_onnx(model_name, exp_config.model.num_classes)
            if export_success:
                _print_and_log(f"{model_name} - Training and export completed")
            else:
                _print_and_log(f"{model_name} - Training completed but ONNX export failed")

        except Exception as e:
            print(f"{model_name} training failed: {e}")
            logger.exception(f"{model_name} training failed: {e}")
            # The first model is treated as mandatory; later ones are optional.
            if model_name == 'resnet50':
                _print_and_log("First model failed. Stopping execution.")
                sys.exit(1)
            _print_and_log("continuing...")
            continue

    _print_and_log("Summary")

    if successful_models:
        _print_and_log("Successfully trained models:")
        for model_name in successful_models:
            test_acc = results[model_name]['metrics']['test_acc']
            _print_and_log(f"{model_name}: Test accuracy = {test_acc:.4f}")

        if len(successful_models) > 1:
            best_model = max(successful_models, key=lambda x: results[x]['metrics']['test_acc'])
            best_acc = results[best_model]['metrics']['test_acc']
            _print_and_log(f"\nBest model: {best_model} with accuracy {best_acc:.4f}")
    else:
        _print_and_log("No models were successfully trained")

    _print_and_log(f"\nResults saved in: artifacts/")
    _print_and_log("Training completed successfully!")


def main():
    """Command-line entry point: train both models, or tune them with ``--tune``.

    Seeds all RNGs with 42, checks that ``data/raw`` holds the expected image
    folders (exiting with status 1 if not) and then either runs the
    hyperparameter grid search per model or trains and exports each model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--tune", action="store_true", help="Run hyperparameter tuning")
    args = parser.parse_args()

    _seed_everything(42)

    _print_and_log("\nCar classification training")
    _print_and_log("Using ResNet50 + MobileNetV3 as models")

    trainer = ModelTrainer(EXPERIMENT_CONFIGS['resnet50'])
    trainer.setup_data_directories()

    if not trainer.check_data_exists("data/raw"):
        for hint in (
            "\nPlease add your images to:",
            "   data/raw/minivan/  (30+ images)",
            "   data/raw/sedan/    (30+ images)",
            "   data/raw/wagon/    (30+ images)",
        ):
            _print_and_log(hint)
        sys.exit(1)

    models_to_try = [
        ('resnet50', EXPERIMENT_CONFIGS['resnet50']),
        ('mobilenetv3_large_100', get_alternative_model_config()),
    ]

    if args.tune:
        hpcfg = HyperparameterSearchConfig()
        for model_name, exp_cfg in models_to_try:
            _tune_model(model_name, exp_cfg, hpcfg)
    else:
        _train_and_export_all(models_to_try)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
