In [1]:
import os

# Move from the notebook folder up to the project root so that the `src.` package
# imports and the relative config/params paths used below resolve.
# Guarded on the presence of `src/` so re-running this cell is idempotent instead of
# climbing one more directory level on every execution.
if not os.path.isdir("src"):
    os.chdir('../')

In [2]:
# Sanity check: display the current working directory. After the first cell it should be
# the project root, so that `src.` imports and relative config paths resolve.
%pwd

'f:\\Project_MultitaskModel'

In [3]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class TrainingModelConfig:
    """Immutable bundle of paths and hyperparameters for the training stage.

    In this notebook it is populated by ``ConfigureManager.get_training_model_config``.
    Field order is also the positional-constructor order, so keep it stable.
    """

    # -- filesystem locations -------------------------------------------------
    root_dir: Path               # created before training; home of training artifacts
    trained_model_path: Path     # destination of the saved state_dict
    base_model_path: Path        # state_dict loaded before training starts
    data_classification: Path    # classification data root; a 'train' sub-folder is read
    data_segmentation: Path      # segmentation data root; a 'train' sub-folder is read

    # -- model construction ---------------------------------------------------
    n_classes: int               # forwarded to MultiTaskModelResNet(n_classes=...)
    n_segment: int               # forwarded to MultiTaskModelResNet(n_segment=...)
    in_channels: int             # forwarded to MultiTaskModelResNet(in_channels=...)

    # -- optimisation ---------------------------------------------------------
    batch_size: int
    epochs: int
    learning_rate: float         # Adam learning rate

    # -- data pipeline --------------------------------------------------------
    image_size: list             # forwarded to data_loader(img_size=...)
    augmentation: bool
    seed: int

    # -- loss / runtime -------------------------------------------------------
    task_num: int                # forwarded to UncertainlyLoss(task_num=...)
    num_workers: int             # forwarded to data_loader(num_workers=...)

In [4]:
from src.Project_MultitaskModel.constants import *
from src.Project_MultitaskModel.utils.common import read_yaml, create_directories

In [5]:
class ConfigureManager:
    """Parse the project YAML files and assemble typed per-stage configuration objects.

    ``self.config`` holds the parsed ``config_filepath`` (paths and stage sections);
    ``self.params`` holds the parsed ``params_filepath`` (hyperparameters).
    The ``artifacts_root`` directory named in the config is created on construction.
    """

    def __init__(self,
                 config_filepath: Path = CONFIG_FILE_PATH,
                 params_filepath: Path = PARAMS_FILE_PATH):
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        create_directories([self.config.artifacts_root])

    def get_training_model_config(self) -> TrainingModelConfig:
        """Build the TrainingModelConfig for the training stage.

        Paths come from the ``training_model`` and ``prepare_base_model`` config
        sections; scalar hyperparameters come from upper-case keys of the params file.
        Side effect: creates ``training_model.root_dir`` before building the config.
        """
        training_section = self.config.training_model
        base_model_section = self.config.prepare_base_model
        hyperparams = self.params
        create_directories([training_section.root_dir])

        # Path-valued fields, each wrapped in pathlib.Path.
        path_fields = {
            "root_dir": Path(training_section.root_dir),
            "trained_model_path": Path(training_section.trained_model_path),
            "base_model_path": Path(base_model_section.base_model_path),
            "data_classification": Path(training_section.data_classification),
            "data_segmentation": Path(training_section.data_segmentation),
        }

        # Dataclass field name -> attribute name in the params file.
        param_keys = (
            ("n_classes", "N_CLASSES"),
            ("n_segment", "N_SEGMENT"),
            ("in_channels", "IN_CHANNELS"),
            ("batch_size", "BATCH_SIZE"),
            ("epochs", "EPOCHS"),
            ("learning_rate", "LEARNING_RATE"),
            ("image_size", "IMAGE_SIZE"),
            ("augmentation", "AUGMENTATION"),
            ("seed", "SEED"),
            ("task_num", "TASK_NUM"),
            ("num_workers", "NUM_WORKERS"),
        )
        scalar_fields = {field: getattr(hyperparams, key) for field, key in param_keys}

        return TrainingModelConfig(**path_fields, **scalar_fields)

In [None]:
import os
import urllib.request as request
from zipfile import ZipFile
import torch
import torchvision.models as models
from src.models.multi_task_model import MultiTaskModelResNet
from src.data.loader_data import data_loader
from src.loss_func.combined_loss import UncertainlyLoss
from src.metrics import calculate_dice, calculate_iou, EarlyStopping

In [None]:
class TrainingModel:
    """Fine-tune the multi-task (classification + segmentation) model described by a config.

    The entry point is ``train_model``. It loads the base weights and builds the training
    loader. It then optimises with Adam for ``config.epochs`` epochs, printing the mean
    loss / Dice / IoU per epoch. Finally it saves the weights to
    ``config.trained_model_path``.
    """

    def __init__(self, config: TrainingModelConfig):
        self.config = config

    def load_base_model(self):
        """Build ``MultiTaskModelResNet`` and load the state dict at ``config.base_model_path``.

        The checkpoint is mapped to CPU so one saved on a GPU also loads on a CPU-only
        machine; ``train_model`` moves the model to the target device afterwards.
        """
        model = MultiTaskModelResNet(
            n_classes=self.config.n_classes,
            n_segment=self.config.n_segment,
            in_channels=self.config.in_channels,
            pretrained=False,  # weights come from base_model_path below
        )
        state_dict = torch.load(self.config.base_model_path, map_location="cpu")
        model.load_state_dict(state_dict)
        return model

    def loader_data(self):
        """Return the shuffled training loader over the ``train`` sub-folders of both datasets."""
        train_class_path = os.path.join(self.config.data_classification, 'train')
        train_seg_path = os.path.join(self.config.data_segmentation, 'train')

        train_loader = data_loader(
            data_classification_path=train_class_path,
            data_segmentation_path=train_seg_path,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            augmentation=self.config.augmentation,
            seed=self.config.seed,
            img_size=self.config.image_size
        )

        return train_loader

    def train_model(self):
        """Train the model, print per-epoch metrics, and save the final weights.

        Raises:
            ValueError: if the training loader yields no batches.
        """
        # Seed torch's global RNG so randomness not covered by the loader's own seed
        # (e.g. dropout, any randomly initialised loss parameters) is reproducible.
        torch.manual_seed(self.config.seed)

        model = self.load_base_model()
        train_loader = self.loader_data()
        num_batches = len(train_loader)
        if num_batches == 0:
            # Without this check the per-epoch averages would divide by zero.
            raise ValueError(
                "Training loader yielded no batches; check the 'train' folders under "
                "data_classification / data_segmentation and the batch_size."
            )

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)

        criterion = UncertainlyLoss(task_num=self.config.task_num)
        trainable_params = list(model.parameters())
        # NOTE(review): an uncertainty-weighted multi-task loss typically owns learnable
        # per-task weights. If UncertainlyLoss is an nn.Module, place it on the same
        # device as the model outputs and let Adam update its parameters as well.
        # If it has no parameters, this adds nothing.
        if isinstance(criterion, torch.nn.Module):
            criterion = criterion.to(device)
            trainable_params.extend(criterion.parameters())
        optimizer = torch.optim.Adam(trainable_params, lr=self.config.learning_rate)

        for epoch in range(self.config.epochs):
            epoch_loss, epoch_dice, epoch_iou = self._train_one_epoch(
                model, train_loader, criterion, optimizer, device, num_batches
            )
            print(f'Epoch [{epoch+1}/{self.config.epochs}], Loss: {epoch_loss:.4f}, Dice: {epoch_dice:.4f}, IoU: {epoch_iou:.4f}')

        self.save_model(model, self.config.trained_model_path)
        print(f'Model saved to {self.config.trained_model_path}')

    def _train_one_epoch(self, model, train_loader, criterion, optimizer, device, num_batches):
        """Run one optimisation pass over ``train_loader``.

        Returns the per-batch means of (loss, Dice, IoU) as Python floats.
        """
        model.train()
        running_loss = running_dice = running_iou = 0.0
        for images, masks, labels in train_loader:
            images = images.to(device)
            masks = masks.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs_class, outputs_seg = model(images)
            loss = criterion(outputs_seg, masks, outputs_class, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Metrics are logging-only: compute them without building an autograd graph.
            with torch.no_grad():
                running_dice += calculate_dice(outputs_seg, masks).item()
                running_iou += calculate_iou(outputs_seg, masks).item()

        return (running_loss / num_batches,
                running_dice / num_batches,
                running_iou / num_batches)

    def save_model(self, model, path: Path):
        """Save ``model.state_dict()`` to ``path``, creating the parent directory if needed."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), path)
            

: 

In [None]:
# Build the training config from the YAML files and run training end to end.
# Exceptions propagate, rather than being printed and swallowed, so the full traceback
# is shown and "Run All" stops when training fails.
config_manager = ConfigureManager()
training_model_config = config_manager.get_training_model_config()

trainer = TrainingModel(config=training_model_config)
trainer.train_model()

[2025-12-19 11:56:14,946: INFO: common: yaml file: configs\config.yaml loaded successfully]
[2025-12-19 11:56:14,951: INFO: common: yaml file: params.yaml loaded successfully]
[2025-12-19 11:56:14,952: INFO: common: created directory at: artifacts]
[2025-12-19 11:56:14,954: INFO: common: created directory at: artifacts/training_model]
Data Augmentation is Disabled
