In [1]:
import os

In [2]:
os.chdir("../")

In [3]:
%pwd

'/Users/arash/ML_End_to_End_Pj/end-to-end-solar-dust-detection'

In [25]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class TrainingConfig:
    root_dir: Path
    trained_model_path: Path
    updated_base_model_path: Path
    training_data: Path
    
    # Training hyperparameters
    params_epochs: int
    params_batch_size: int
    params_is_augmentation: bool
    params_image_size: list
    params_learning_rate: float
    params_classes: int

In [13]:
from solar_dust_detection.constants import *
from solar_dust_detection.utils.common import read_yaml, create_directories

In [26]:
class ConfigurationManager:
    def __init__(
        self,
        config_filepath: Path = CONFIG_FILE_PATH,
        params_filepath: Path = PARAMS_FILE_PATH,
    ):
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        create_directories([Path(self.config.artifacts_root)])
        

    def get_training_config(self) -> TrainingConfig:
        training_config = self.config.training

        root_dir = Path(training_config.root_dir)
        trained_model_path = Path(training_config.trained_model_path)
        updated_base_model_path = Path(self.config.base_model.updated_base_model_path)
        training_data = os.path.join(self.config.data_ingestion.unzipped_data_dir, "Detect_solar_dust")
        create_directories([root_dir])
        
        params_epochs = self.params.EPOCHS
        params_batch_size = self.params.BATCH_SIZE
        params_is_augmentation = self.params.AUGMENTATION
        params_image_size = self.params.IMAGE_SIZE
        params_learning_rate = self.params.LEARNING_RATE
        params_classes = self.params.CLASSES

        training_config = TrainingConfig(
            root_dir=root_dir,
            trained_model_path=trained_model_path,
            updated_base_model_path=updated_base_model_path,
            training_data=training_data,
            params_epochs=params_epochs,
            params_batch_size=params_batch_size,
            params_is_augmentation=params_is_augmentation,
            params_image_size=params_image_size,
            params_learning_rate=params_learning_rate,
            params_classes=params_classes,
        )

        return training_config

In [29]:
# Components update
import os
import urllib.request as request
from zipfile import ZipFile
import torch
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms, datasets, models
from solar_dust_detection import logger
import time

class MapDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.dataset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.dataset)

class Training:
    def __init__(self, config: TrainingConfig):
        self.config = config
        self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

    def get_base_model(self):
        self.model = models.resnet18(weights=None)
        
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, self.config.params_classes)
        
        checkpoint = torch.load(self.config.updated_base_model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint)
        
        # 4. Move to GPU/CPU
        self.model.to(self.device)
        
    def train_valid_generator(self):
        
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        val_transforms = transforms.Compose([
            transforms.Resize(self.config.params_image_size[:-1]), # Resize to (224, 224)
            transforms.ToTensor(),
            normalize
        ])

        if self.config.params_is_augmentation:
            train_transforms = transforms.Compose([
                transforms.Resize(self.config.params_image_size[:-1]),
                transforms.RandomRotation(40),
                transforms.RandomHorizontalFlip(),
                transforms.RandomAffine(degrees=0, translate=(0.2, 0.2), shear=0.2),
                transforms.ToTensor(),
                normalize
            ])
        else:
            train_transforms = val_transforms       

        # ImageFolder expects structure: data/class_a/img1.jpg, data/class_b/img2.jpg
        full_dataset = datasets.ImageFolder(root=self.config.training_data)
        
        val_size = int(len(full_dataset) * 0.20)
        train_size = len(full_dataset) - val_size
        
        train_subset, val_subset = random_split(full_dataset, [train_size, val_size])

        train_dataset = MapDataset(train_subset, train_transforms)
        val_dataset = MapDataset(val_subset, val_transforms)

        self.train_loader = DataLoader(
            train_dataset, 
            batch_size=self.config.params_batch_size, 
            shuffle=True, 
            num_workers=0 
        )
        
        self.valid_loader = DataLoader(
            val_dataset, 
            batch_size=self.config.params_batch_size, 
            shuffle=False,
            num_workers=0
        )

    @staticmethod
    def save_model(path: Path, model: nn.Module):
        torch.save(model.state_dict(), path)

    def train(self):
        # 1. Define Loss and Optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(self.model.parameters(), lr=self.config.params_learning_rate)

        print(f"Training on {self.device} with {len(self.train_loader.dataset)} samples.")

        # 2. The Training Loop 
        for epoch in range(self.config.params_epochs):
            self.model.train() 
            running_loss = 0.0
            correct = 0
            total = 0

            # --- Batch Loop ---
            for images, labels in self.train_loader:
                images, labels = images.to(self.device), labels.to(self.device)

                # Zero gradients
                optimizer.zero_grad()

                # Forward pass
                outputs = self.model(images)
                loss = criterion(outputs, labels)

                # Backward pass and optimize
                loss.backward()
                optimizer.step()

                # Metrics
                running_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            # --- Epoch Metrics ---
            epoch_acc = 100 * correct / total
            logger.info(f"Epoch [{epoch+1}/{self.config.params_epochs}] "
                  f"Loss: {running_loss/len(self.train_loader):.4f} "
                  f"Acc: {epoch_acc:.2f}%")


        self.save_model(
            path=self.config.trained_model_path,
            model=self.model
        )
        logger.info(f"Model saved to {self.config.trained_model_path}")

In [30]:
# Pipeline

try:
    config = ConfigurationManager()
    training_config = config.get_training_config()
    training = Training(config=training_config)
    training.get_base_model()
    training.train_valid_generator()
    training.train()
except Exception as e:
    logger.exception(e)
    raise e

[2026-01-24 17:07:15,760: INFO: common]: YAML file: config/config.yaml loaded successfully
[2026-01-24 17:07:15,766: INFO: common]: YAML file: params.yaml loaded successfully
[2026-01-24 17:07:15,767: INFO: common]: Created directory at: artifacts
[2026-01-24 17:07:15,768: INFO: common]: Created directory at: artifacts/training
[2026-01-24 17:07:15,768: INFO: 31722983]: Using device: mps
Training on mps with 2050 samples.




[2026-01-24 17:08:13,527: INFO: 31722983]: Epoch [1/40] Loss: 0.5429 Acc: 72.59%
[2026-01-24 17:08:59,534: INFO: 31722983]: Epoch [2/40] Loss: 0.4928 Acc: 78.15%
[2026-01-24 17:09:45,355: INFO: 31722983]: Epoch [3/40] Loss: 0.4056 Acc: 82.59%
[2026-01-24 17:10:32,885: INFO: 31722983]: Epoch [4/40] Loss: 0.3734 Acc: 84.20%
[2026-01-24 17:11:19,864: INFO: 31722983]: Epoch [5/40] Loss: 0.3519 Acc: 85.71%
[2026-01-24 17:12:06,006: INFO: 31722983]: Epoch [6/40] Loss: 0.3365 Acc: 85.37%
[2026-01-24 17:12:52,556: INFO: 31722983]: Epoch [7/40] Loss: 0.2734 Acc: 88.88%
[2026-01-24 17:13:42,026: INFO: 31722983]: Epoch [8/40] Loss: 0.2982 Acc: 88.24%
[2026-01-24 17:14:28,194: INFO: 31722983]: Epoch [9/40] Loss: 0.2786 Acc: 88.34%
[2026-01-24 17:15:14,797: INFO: 31722983]: Epoch [10/40] Loss: 0.2653 Acc: 89.02%
[2026-01-24 17:16:00,875: INFO: 31722983]: Epoch [11/40] Loss: 0.2158 Acc: 92.00%
[2026-01-24 17:16:47,174: INFO: 31722983]: Epoch [12/40] Loss: 0.2234 Acc: 91.22%
[2026-01-24 17:17:33,470: