In [2]:
import os
# Switch the working directory to the parent of the current one.
# Presumably this makes the `src.` imports and relative config paths in later
# cells resolve from the project root — TODO confirm the intended launch directory.
# The path is relative, so re-running this cell moves up one more level each time.
os.chdir("../")

In [3]:
import torch
from pathlib import Path
from dataclasses import dataclass
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from src.vision_Transformer.constants import *
from src.vision_Transformer.utils.common import read_yaml, create_directories

In [4]:
@dataclass(frozen=True)
class ModelTrainerConfig:
    """Immutable bundle of settings for the model-training stage.

    Instances are frozen: assigning to a field after construction raises
    ``dataclasses.FrozenInstanceError``. Field order is significant for
    positional construction and must not be changed.
    """

    # --- artifact locations -------------------------------------------------
    root_dir: Path
    model_ckpt: str
    train_accuracy: Path
    train_loss: Path

    # --- training hyperparameters -------------------------------------------
    batch_size: int
    epochs: int
    learning_rate: float
    patch_size: int
    num_classes: int
    image_size: int
    channels: int
    embed_dim: int
    num_heads: int
    depth: int
    mlp_dim: int
    drop_rate: float
    weight_decay: float


In [5]:
class ConfigurationManager:
    """Read the project's YAML settings and build stage-specific config objects.

    Both files are parsed with ``read_yaml`` at construction time and kept on
    the instance as ``config`` and ``params``.
    """

    def __init__(self, config_file_path = CONFIG_FILE_PATH, params_file_path = PARAMS_FILE_PATH):
        """Parse both settings files and register the artifacts root directory.

        Args:
            config_file_path: Location of the main configuration file.
            params_file_path: Location of the hyperparameter file.
        """
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)

        artifacts_root = self.config.artifacts_root
        create_directories([artifacts_root])

    def get_model_trainer_config(self) -> ModelTrainerConfig:
        """Assemble the ``ModelTrainerConfig`` for the training stage.

        Artifact locations are read from the ``model_trainer`` section of the
        config file and hyperparameters from the ``TrainingArguments`` section
        of the params file. The stage's ``root_dir`` is handed to
        ``create_directories`` before the config object is built.

        Returns:
            ModelTrainerConfig: The populated, frozen settings object.
        """
        trainer_section = self.config.model_trainer
        hyperparams = self.params.TrainingArguments

        create_directories([trainer_section.root_dir])

        # Fields copied verbatim from the `model_trainer` section.
        artifact_fields = ("root_dir", "model_ckpt", "train_accuracy", "train_loss")
        # Fields whose key in `TrainingArguments` is the upper-cased field name.
        hyperparam_fields = (
            "batch_size", "epochs", "learning_rate", "patch_size", "num_classes",
            "image_size", "channels", "embed_dim", "num_heads", "depth",
            "mlp_dim", "drop_rate", "weight_decay",
        )

        settings = {name: getattr(trainer_section, name) for name in artifact_fields}
        settings.update(
            {name: getattr(hyperparams, name.upper()) for name in hyperparam_fields}
        )
        return ModelTrainerConfig(**settings)

In [6]:
from torch.utils.data import DataLoader

In [7]:
from src.vision_Transformer.logging import logger

In [None]:
class Model_trainer:
    """Prepare CIFAR-10 data loading for the training stage.

    Constructing the trainer picks the compute device, builds the image
    transforms, loads both CIFAR-10 splits from ``config.dataset_dir`` and
    wraps them in ``DataLoader`` objects exposed as ``train_loader`` and
    ``test_loader``. ``model_trainer`` and ``train`` are unimplemented stubs.

    Args:
        config: Settings object exposing at least ``batch_size`` and
            ``dataset_dir``.
    """

    def __init__(self, config: "ModelTrainerConfig"):
        # NOTE(review): `dataset_dir` is read in `transformed_dataset`, but it is
        # not a field of the `ModelTrainerConfig` dataclass defined in this
        # notebook — confirm which settings object is meant to be passed here.
        self.config = config

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # The transforms must exist before the datasets that use them are built.
        self.data_augmentation()
        train_dataset, test_dataset = self.transformed_dataset()

        # Pinned host memory only helps host-to-GPU copies, so skip it on CPU.
        pin_memory = self.device == "cuda"

        self.train_loader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            pin_memory=pin_memory,
        )
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            pin_memory=pin_memory,
        )

    # ---------------------------------------------------------------------------
    def data_augmentation(self):
        """Build the image pipelines and store them on the instance.

        Sets ``after_transforms`` (random crop, flip and colour jitter followed
        by tensor conversion and normalisation) for the training split, and
        ``eval_transforms`` (tensor conversion and normalisation only) for the
        test split so that evaluation is deterministic.
        """
        self.after_transforms = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
        ])
        # Same tensor conversion and normalisation as training, without the
        # random augmentation steps.
        self.eval_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
        ])

    def transformed_dataset(self):
        """Load both CIFAR-10 splits from disk with their transforms attached.

        The data is read from ``self.config.dataset_dir`` with
        ``download=False``, so it must already be present there.

        Returns:
            tuple: ``(train_dataset, test_dataset)``.
        """
        # Allow calling this method on an instance whose transforms were not
        # built yet (for example one created without running __init__).
        if not (hasattr(self, "after_transforms") and hasattr(self, "eval_transforms")):
            self.data_augmentation()

        transformed_train_dataset = datasets.CIFAR10(
            root=self.config.dataset_dir,
            train=True,
            download=False,
            transform=self.after_transforms,
        )
        logger.info("Train Dataset Transformed Successfully")
        print("Train Dataset Transformed Successfully")

        transformed_test_dataset = datasets.CIFAR10(
            root=self.config.dataset_dir,
            train=False,
            download=False,
            transform=self.eval_transforms,
        )
        logger.info("Test Dataset Transformed Successfully")
        print("Test Dataset Transformed Successfully")

        return transformed_train_dataset, transformed_test_dataset
    # ---------------------------------------------------------------------------

    def model_trainer(self):
        """Placeholder: not implemented yet."""
        pass

    def train(self):
        """Placeholder: not implemented yet."""
        pass
        