In [1]:
import os
# Move the working directory one level up from wherever the kernel started.
# Presumably this makes the project root the cwd so the `src.*` imports and the
# relative config paths used by later cells resolve — TODO confirm.
# NOTE: the path is relative, so re-running this cell without restarting the
# kernel climbs one more directory each time.
os.chdir("../")

In [2]:
import torch
from pathlib import Path
from dataclasses import dataclass
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from src.vision_Transformer.constants import *
from src.vision_Transformer.utils.common import read_yaml, create_directories

In [3]:
@dataclass(frozen=True)
class ModelTrainerConfig:
    """Immutable settings record handed to the data and training components.

    Instances are read-only (``frozen=True``). Field order is part of the
    positional constructor signature and must not be rearranged.
    """

    # --- values taken from the `model_trainer` section of the config file ---
    root_dir: Path
    trained_model: str
    train_accuracy: Path
    train_loss: Path
    data_dir: Path

    # --- values taken from the `TrainingArguments` section of the params file ---
    batch_size: int
    epochs: int
    learning_rate: float
    patch_size: int
    num_classes: int
    image_size: int
    channels: int
    embed_dim: int
    num_heads: int
    depth: int
    mlp_dim: int
    dropout_rate: float
    weight_decay: float


In [4]:
class ConfigurationManager:
    """Loads the config and params YAML files and exposes typed config objects."""

    def __init__(self, config_file_path = CONFIG_FILE_PATH ,params_file_path = PARAMS_FILE_PATH):
        """Read both YAML files and create the artifacts root directory.

        Args:
            config_file_path: location of the config YAML (defaults to CONFIG_FILE_PATH).
            params_file_path: location of the params YAML (defaults to PARAMS_FILE_PATH).
        """
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)

        artifacts_root = self.config.artifacts_root
        create_directories([artifacts_root])

    def get_model_trainer_config(self) -> ModelTrainerConfig:
        """Assemble a ModelTrainerConfig from the two loaded files.

        Creates the trainer's ``root_dir`` as a side effect, then copies the
        ``model_trainer`` entries and the ``TrainingArguments`` hyper-parameters
        into the returned dataclass.
        """
        trainer_section = self.config.model_trainer
        hyper_params = self.params.TrainingArguments

        create_directories([trainer_section.root_dir])

        return ModelTrainerConfig(
            # entries from the config file
            root_dir=trainer_section.root_dir,
            trained_model=trainer_section.trained_model,
            train_accuracy=trainer_section.train_accuracy,
            train_loss=trainer_section.train_loss,
            data_dir=trainer_section.data_dir,
            # hyper-parameters from the params file
            batch_size=hyper_params.BATCH_SIZE,
            epochs=hyper_params.EPOCHS,
            learning_rate=hyper_params.LEARNING_RATE,
            patch_size=hyper_params.PATCH_SIZE,
            num_classes=hyper_params.NUM_CLASSES,
            image_size=hyper_params.IMAGE_SIZE,
            channels=hyper_params.CHANNELS,
            embed_dim=hyper_params.EMBED_DIM,
            num_heads=hyper_params.NUM_HEADS,
            depth=hyper_params.DEPTH,
            mlp_dim=hyper_params.MLP_DIM,
            dropout_rate=hyper_params.DROPOUT_RATE,
            weight_decay=hyper_params.WEIGHT_DECAY,
        )

In [5]:
from torch.utils.data import DataLoader

In [6]:
from src.vision_Transformer.logging import logger

In [7]:
class DataTransformation:
    """Builds the CIFAR-10 train and test datasets with their preprocessing.

    The training split is randomly augmented. The test split only receives the
    deterministic tensor conversion and normalisation, so evaluation results do
    not vary from run to run because of random crops, flips or colour jitter.
    """

    def __init__(self , config : ModelTrainerConfig):
        """Store the config; the transform pipelines are created later.

        Args:
            config: trainer configuration; only ``data_dir`` is read here.
        """
        self.config = config
        # Filled in by data_augmentation(); transformed_dataset() builds any
        # pipeline that is still missing, so call order no longer matters.
        self.after_transforms = None
        self.eval_transforms = None

    @staticmethod
    def _base_transforms():
        """Return the deterministic steps shared by the train and test pipelines."""
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5]*3 , std = [0.5]*3),
        ]

    def data_augmentation(self):
        """Create the transform pipelines.

        Sets ``self.after_transforms`` (training: random crop, flip and colour
        jitter followed by the base steps) and ``self.eval_transforms``
        (test: base steps only).
        """
        self.after_transforms = transforms.Compose([
            transforms.RandomCrop(32 , padding=4),  # CIFAR-10 images are 32x32
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.2 ,contrast= 0.2, saturation=0.2 , hue=0.1),
            *self._base_transforms(),
        ])
        self.eval_transforms = transforms.Compose(self._base_transforms())

    def transformed_dataset(self):
        """Load the CIFAR-10 splits from ``config.data_dir`` with transforms attached.

        The data must already be on disk (``download=False``).

        Returns:
            tuple: ``(train_dataset, test_dataset)``.
        """
        # Build whatever is missing instead of failing with AttributeError when
        # data_augmentation() was not called first.
        if self.after_transforms is None:
            self.data_augmentation()
        if self.eval_transforms is None:
            self.eval_transforms = transforms.Compose(self._base_transforms())

        transformed_train_dataset = datasets.CIFAR10(
            root = self.config.data_dir,
            train = True,
            download= False,
            transform= self.after_transforms,
        )
        logger.info("Train Dataset Transformed Successfully")
        print("Train Dataset Transformed Successfully")

        # Evaluation data must not be randomly augmented.
        transformed_test_dataset = datasets.CIFAR10(
            root = self.config.data_dir,
            train = False,
            download= False,
            transform= self.eval_transforms,
        )
        logger.info("Test Dataset Transformed Successfully")
        print("Test Dataset Transformed Successfully")

        return transformed_train_dataset , transformed_test_dataset

In [8]:
# Load the YAML-backed configuration and extract the trainer settings.
config = ConfigurationManager()
model_trainer_config = config.get_model_trainer_config()

# Build the transform pipeline; data_augmentation() must run before
# transformed_dataset(), which reads the pipeline it stores on the instance.
data_transfromation = DataTransformation(model_trainer_config)
data_transfromation.data_augmentation()

# CIFAR-10 train/test datasets, reused by later cells.
train_dataset , test_dataset = data_transfromation.transformed_dataset()

[2025-08-08 15:53:16,354 : INFO : common  : yaml file config\config.yaml was read succesfully]
[2025-08-08 15:53:16,362 : INFO : common  : yaml file params.yaml was read succesfully]
[2025-08-08 15:53:16,364 : INFO : common  : Created directory at : artifacts]
[2025-08-08 15:53:16,364 : INFO : common  : Created directory at : artifacts/model]
[2025-08-08 15:53:17,000 : INFO : 4111442926  : Train Dataset Transformed Successfully]
Train Dataset Transformed Successfully
[2025-08-08 15:53:17,492 : INFO : 4111442926  : Test Dataset Transformed Successfully]
Test Dataset Transformed Successfully


In [9]:
from src.vision_Transformer.Components.ViT_Component.Vision_Transformer_Class import Vision_Transformer_Class


In [21]:
class Model_trainer:
    """Holds the data loaders and the Vision Transformer built from the config."""

    def __init__(self, config: ModelTrainerConfig , train_dataset , test_dataset):
        """Create the data loaders and instantiate the model on the chosen device.

        Args:
            config: trainer configuration supplying batch size and model hyper-parameters.
            train_dataset: dataset wrapped by the shuffled training loader.
            test_dataset: dataset wrapped by the unshuffled test loader.
        """
        self.config = config

        use_cuda = torch.cuda.is_available()
        self.device = "cuda" if use_cuda else "cpu"

        self.train_dataset = train_dataset
        self.test_dataset = test_dataset

        # Page-locked host memory only helps host-to-GPU copies; requesting it
        # on a CPU-only machine just makes PyTorch emit a warning.
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            pin_memory=use_cuda,
        )
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            pin_memory=use_cuda,
        )

        self.model = Vision_Transformer_Class(
            image_size = self.config.image_size,
            patch_size = self.config.patch_size,
            in_channels = self.config.channels,
            num_classes = self.config.num_classes,
            embed_dim = self.config.embed_dim,
            num_heads = self.config.num_heads,
            depth = self.config.depth,
            mlp_dim = self.config.mlp_dim,
            dropout_rate = self.config.dropout_rate
        ).to(self.device)

    # ---------------------------------------------------------------------------
    def show_model(self):
        """Print a banner followed by the model's module structure."""
        print("\n\n------------------------------->Model Configuration<------------------------------------")
        print("\n", self.model)

    def train(self):
        """Training loop placeholder; currently does nothing."""
        # TODO: implement the training loop.
        pass

In [22]:
# Rebuild the configuration, datasets and model from scratch, then print the model.
# No try/except wrapper: catching an exception only to re-raise it changes
# nothing except adding a frame to the traceback.
config = ConfigurationManager()
model_trainer_config = config.get_model_trainer_config()

data_transfromation = DataTransformation(model_trainer_config)
data_transfromation.data_augmentation()
train_dataset, test_dataset = data_transfromation.transformed_dataset()

model_trainer = Model_trainer(model_trainer_config , train_dataset=train_dataset , test_dataset= test_dataset)
model_trainer.show_model()

[2025-08-08 15:54:57,959 : INFO : common  : yaml file config\config.yaml was read succesfully]
[2025-08-08 15:54:57,963 : INFO : common  : yaml file params.yaml was read succesfully]
[2025-08-08 15:54:57,965 : INFO : common  : Created directory at : artifacts]
[2025-08-08 15:54:57,966 : INFO : common  : Created directory at : artifacts/model]
[2025-08-08 15:54:58,650 : INFO : 4111442926  : Train Dataset Transformed Successfully]
Train Dataset Transformed Successfully
[2025-08-08 15:54:59,276 : INFO : 4111442926  : Test Dataset Transformed Successfully]
Test Dataset Transformed Successfully


------------------------------->Model Configuration<------------------------------------

 Vision_Transformer_Class(
  (patch_embedding): PatchEmbedding(
    (projection): Conv2d(3, 256, kernel_size=(4, 4), stride=(4, 4))
  )
  (encoder_layer): Sequential(
    (0): Transformer_Encoder_Layer(
      (normalization_layer_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (multi_head_atten