In [1]:
import torch
import torch.utils.data
import torchvision
import wandb
import yaml
import argparse
from tqdm import tqdm
import gc

In [None]:
# Authenticate with Weights & Biases (IPython shell escape); the API key is appended to the user's netrc file.
!wandb login

wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\shrey\_netrc
wandb: W&B API key is configured. Use `wandb login --relogin` to force relogin


In [2]:
class CNN_model(torch.nn.Module):
    """
    A Convolutional Neural Network (CNN) model for image classification.

    The network is a stack of convolutional blocks, each
    Conv2d -> BatchNorm2d/Identity -> activation -> Dropout2d/Identity -> MaxPool2d,
    followed by a fully connected head
    Flatten -> Linear -> BatchNorm1d/Identity -> activation -> Dropout/Identity -> Linear.

    Args:
        input_size: Spatial size of the input images: an int (square) or an
            (height, width) tuple/list.
        output_size: Number of output classes.
        num_filters: Output channels of each convolutional block; its length sets
            the number of blocks.
        kernel_sizes: Per-block convolution kernel size (int or 2-tuple/list).
        pool_kernels: Per-block max-pool kernel size, also used as the pool stride.
        paddings: Per-block convolution padding.
        conv_strides: Per-block convolution stride.
        dense_layer: Width of the hidden fully connected layer.
        activation_fn: One of 'relu', 'sigmoid', 'tanh', 'selu', 'gelu', 'mish',
            'leakyrelu' (case-insensitive); unknown names fall back to ReLU.
        use_softmax: If truthy, forward() returns softmax probabilities instead of
            raw logits. Do not combine with torch.nn.CrossEntropyLoss, which applies
            log-softmax itself and expects logits.
        batch_norm: If truthy, add batch normalisation after every conv block and
            after the hidden dense layer.
        dropout_rate: Dropout probability; 0 disables dropout entirely.
        in_channels: Number of channels of the input images (default 3).

    Raises:
        ValueError: If the spatial size shrinks below 1x1 before the dense head.
    """
    def __init__(
            self, input_size, output_size, 
            num_filters, kernel_sizes, pool_kernels, 
            paddings, conv_strides, dense_layer, 
            activation_fn, use_softmax=False, 
            batch_norm=True, dropout_rate=0.0,
            in_channels=3
        ):

        super().__init__()

        self.conv_blocks = torch.nn.ModuleList()
        self.dropout_rate = dropout_rate

        # Map names to classes so only the requested activation gets instantiated.
        activations = {
            'relu': torch.nn.ReLU,
            'sigmoid': torch.nn.Sigmoid,
            'tanh': torch.nn.Tanh,
            'selu': torch.nn.SELU,
            'gelu': torch.nn.GELU,
            'mish': torch.nn.Mish,
            'leakyrelu': torch.nn.LeakyReLU,
        }

        def identify_activation(function_name):
            # Unknown names fall back to ReLU (kept from the original behaviour).
            return activations.get(function_name.lower(), torch.nn.ReLU)()

        def make_tuple(a):
            # Tuples and lists (wandb configs deliver sequences as lists) become
            # tuples; a scalar a becomes (a, a).
            if isinstance(a, (tuple, list)):
                return tuple(a)
            return (a, a)

        h, w = make_tuple(input_size)
        for i in range(len(num_filters)):
            kernel = make_tuple(kernel_sizes[i])
            padding = make_tuple(paddings[i])
            stride = make_tuple(conv_strides[i])
            pool = make_tuple(pool_kernels[i])

            # Module order inside the block is fixed so state_dict keys stay stable.
            block = torch.nn.Sequential(
                torch.nn.Conv2d(
                    in_channels=in_channels if i == 0 else num_filters[i-1],
                    out_channels=num_filters[i],
                    kernel_size=kernel,
                    padding=padding,
                    stride=stride
                ),
                torch.nn.BatchNorm2d(num_filters[i]) if batch_norm else torch.nn.Identity(),
                identify_activation(activation_fn),
                torch.nn.Dropout2d(p=dropout_rate) if dropout_rate > 0 else torch.nn.Identity(),
                torch.nn.MaxPool2d(kernel_size=pool, stride=pool)
            )
            self.conv_blocks.append(block)

            # Track the spatial size for the first Linear layer:
            # Conv2d:    floor((n + 2p - k) / s) + 1
            # MaxPool2d: floor((n - k) / k) + 1   (stride == kernel, no padding)
            h = (h + 2 * padding[0] - kernel[0]) // stride[0] + 1
            w = (w + 2 * padding[1] - kernel[1]) // stride[1] + 1
            h = (h - pool[0]) // pool[0] + 1
            w = (w - pool[1]) // pool[1] + 1
            if h < 1 or w < 1:
                raise ValueError(
                    f"Spatial size collapses to {h}x{w} after conv block {i}; "
                    "use a larger input_size or smaller kernels/strides/pools."
                )

        # Fully connected head on the flattened feature map.
        self.dense_layer = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=num_filters[-1]*h*w, out_features=dense_layer, bias=True),
            torch.nn.BatchNorm1d(dense_layer) if batch_norm else torch.nn.Identity(),
            identify_activation(activation_fn),
            torch.nn.Dropout(p=dropout_rate) if dropout_rate > 0 else torch.nn.Identity(),
            torch.nn.Linear(in_features=dense_layer, out_features=output_size),
        )

        # Optional softmax on the outputs (see the class docstring caveat).
        self.use_softmax = use_softmax
        if use_softmax:
            self.softmax_layer = torch.nn.Softmax(dim=1)

    def forward(self, x):
        """Run the conv blocks, then the dense head; softmax only if use_softmax."""
        for block in self.conv_blocks:
            x = block(x)

        x = self.dense_layer(x)

        if self.use_softmax:
            x = self.softmax_layer(x)

        return x

class Dataset(torch.utils.data.Dataset):
    """
    Image-folder dataset with optional training-time augmentation.

    Images are loaded with ``torchvision.datasets.ImageFolder`` (one sub-directory
    per class). Each image is resized to ``input_size``, optionally augmented,
    converted to a tensor and normalised with fixed per-channel statistics.

    Args:
        data_dir: Root directory handed to ImageFolder.
        input_size: Target (height, width); an int means a square image.
        data_augmentation: If truthy, apply random horizontal flip, random rotation
            (up to 20 degrees) and colour jitter after resizing.
    """
    def __init__(self, data_dir, input_size=(224,224), data_augmentation=False):

        super().__init__()

        def make_tuple(a):
            # Tuples/lists pass through as tuples; a scalar a becomes (a, a).
            if isinstance(a, (tuple, list)):
                return tuple(a)
            return (a, a)

        transforms = torchvision.transforms
        resize = transforms.Resize(make_tuple(input_size))
        to_tensor = transforms.ToTensor()
        # NOTE(review): hard-coded channel statistics, presumably computed on the
        # training images — confirm before reusing with another dataset.
        normalize = transforms.Normalize(mean=[0.4716, 0.4602, 0.3899], std=[0.2382, 0.2273, 0.2361])

        # Build the random ops only when requested (no identity Lambdas, which
        # also keeps the transform picklable for multi-worker DataLoaders).
        augmentations = []
        if data_augmentation:
            augmentations = [
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(20),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            ]

        self.data_dir = data_dir
        self.data_augmentation = bool(data_augmentation)
        # Deterministic pipeline (no random ops), used for the validation split.
        self.eval_transform = transforms.Compose([resize, to_tensor, normalize])
        self.data = torchvision.datasets.ImageFolder(
            root=data_dir,
            transform=transforms.Compose([resize, *augmentations, to_tensor, normalize]),
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def get_classes(self):
        """Return the class names discovered by ImageFolder."""
        return self.data.classes

    def stratified_split(self, val_ratio=0.2, seed=42):
        """
        Split into train/validation subsets while preserving class proportions.

        For each class with n samples, ``int(n * val_ratio)`` of them (picked by a
        seeded random permutation) go to validation and the rest to training.
        The validation subset always uses the non-augmented transform, so
        validation metrics (and early stopping) are not perturbed by random
        augmentation.

        Returns:
            (train_subset, val_subset) as ``torch.utils.data.Subset`` objects.
        """
        # Group sample indices by class label (insertion order = first appearance).
        class_to_indices = {}
        for idx, (_, label) in enumerate(self.data.samples):
            class_to_indices.setdefault(label, []).append(idx)

        train_indices = []
        val_indices = []
        generator = torch.Generator().manual_seed(seed)

        for indices in class_to_indices.values():
            n_val = int(len(indices) * val_ratio)
            order = torch.randperm(len(indices), generator=generator).tolist()
            val_indices.extend(indices[i] for i in order[:n_val])
            train_indices.extend(indices[i] for i in order[n_val:])

        if self.data_augmentation:
            # Re-scan the same root with the deterministic transform. ImageFolder
            # sorts classes and files, so the indices computed above refer to the
            # same images; verify that rather than assume it.
            val_source = torchvision.datasets.ImageFolder(root=self.data_dir, transform=self.eval_transform)
            if val_source.samples != self.data.samples:
                raise RuntimeError(f"Contents of {self.data_dir!r} changed while building the validation split")
        else:
            val_source = self.data

        train_subset = torch.utils.data.Subset(self.data, train_indices)
        val_subset = torch.utils.data.Subset(val_source, val_indices)
        return train_subset, val_subset
    
def test_CNN_model(model, test_loader, device, test_logging = True):
    """
    This function is used to test the model on the test set.
    It calculates the accuracy and loss of the model on the test set.
    """
    model.eval()
    criterion = torch.nn.CrossEntropyLoss()
    test_loss = 0.0
    correct = 0
    total = 0

    # Disable gradient calculation for testing
    with torch.no_grad():
        # Iterate through the test data
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            
            # Accumulate loss and correct predictions
            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Calculate accuracy
    accuracy = 100 * correct / total

    # Log test results to wandb if required
    if test_logging:
        wandb.log({
            "test_accuracy": accuracy, 
            "test_loss": test_loss / len(test_loader) 
        })
    
    return accuracy, test_loss / len(test_loader)

def train_CNN_model(model, train_loader, val_loader, learning_rate, epochs, device, patience = 3):
    """
    Train ``model`` with Adam and cross-entropy, with early stopping on validation loss.

    After every epoch the model is evaluated on ``val_loader`` (test_CNN_model with
    test_logging=False), and train/validation metrics are logged to wandb.
    Training stops once the validation loss has failed to improve by more than
    1e-3 for ``patience`` consecutive epochs. The weights from the best epoch are
    then restored into ``model``.

    Args:
        model: Module returning class logits; should already be on ``device``.
        train_loader: Iterable of (images, labels) training batches.
        val_loader: Iterable of (images, labels) validation batches.
        learning_rate: Adam learning rate.
        epochs: Maximum number of epochs.
        device: Device the batches are moved to.
        patience: Epochs without improvement tolerated before stopping.

    Returns:
        The same model instance, holding the best-validation-loss weights.

    Raises:
        ValueError: If ``train_loader`` yields no samples in an epoch.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        loss_sum = 0.0
        correct = 0
        total = 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False):
            images = images.to(device)
            labels = labels.to(device)

            scores = model(images)
            loss = criterion(scores, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # `loss` is the batch mean; weight by batch size for a per-sample epoch mean.
            loss_sum += loss.item() * batch_size
            correct += (scores.argmax(dim=1) == labels).sum().item()
            total += batch_size

        if total == 0:
            raise ValueError("train_loader yielded no samples")

        train_accuracy = 100 * correct / total

        # Also switches the model to eval mode; model.train() above restores it.
        val_accuracy, val_loss = test_CNN_model(model, val_loader, device, test_logging=False)

        wandb.log({
            "train_loss": loss_sum / total,
            "train_accuracy": train_accuracy,
            "validation_accuracy": val_accuracy,
            "validation_loss": val_loss,
            "epoch": epoch+1
        })

        # Early stopping: an improvement must beat the best loss by more than 1e-3.
        if val_loss < best_val_loss - 1e-3:
            best_val_loss = val_loss
            patience_counter = 0
            # state_dict() holds references to the live tensors, so a shallow dict
            # copy would keep changing as training continues; clone every tensor.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            patience_counter += 1
            if patience_counter >= patience:
                break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model

In [3]:
def get_config_value(config, args, key, default=None):
    """
    Look up ``key`` on ``config`` first, then on ``args``, then return ``default``.

    Both lookups use attribute access. ``config`` is typically ``wandb.config``;
    ``args`` is an argparse-like namespace of fallback values. A config object
    whose ``__getattr__`` signals a missing key with KeyError (instead of
    AttributeError) is also treated as "missing", so the fallback chain still
    applies.
    """
    try:
        return getattr(config, key)
    except (AttributeError, KeyError):
        return getattr(args, key, default)

def train_model(config=None):
    """
    Run one full experiment and log it to Weights & Biases.

    Steps: build the train/validation split, train a CNN_model with early
    stopping, then evaluate on the test folder.

    Hyperparameters are read from ``wandb.config``, which is populated by a sweep
    agent or by the ``config`` argument. Any key missing there falls back to
    the ``defaults`` dict below.

    Args:
        config: Optional dict of hyperparameters forwarded to ``wandb.init``.
    """
    from types import SimpleNamespace

    defaults = {
        'batch_size': 64,
        'input_size': 224,
        'epochs': 10,
        'patience': 3,
        'data_augmentation': False,
        'number_of_filters': [64, 128, 256, 512, 1024],
        'kernel_sizes': [3, 3, 3, 5, 5],
        'pool_kernels': [3, 3, 2, 2, 2],
        'paddings': [1, 1, 1, 1, 1],
        'conv_strides': [1, 1, 1, 1, 1],
        'dense_layer': 256,
        'activation_fn': 'gelu',
        # Must stay 0: training/evaluation use CrossEntropyLoss, which applies
        # log-softmax itself and expects raw logits. Softmax outputs would be
        # "double-softmaxed" and squash the gradients.
        'use_softmax': 0,
        'batch_norm': 1,
        'dropout_rate': 0.025831017228233916,
        'learning_rate': 0.0001329901471283676,
        # NOTE(review): absolute local paths; override via config on other machines.
        'dataset_train': "C:/Users/shrey/Desktop/ACAD/DL/nature_12K/inaturalist_12K/train",
        'dataset_test': "C:/Users/shrey/Desktop/ACAD/DL/nature_12K/inaturalist_12K/val"
    }

    with wandb.init(entity='me21b138-indian-institute-of-technology-madras', project='AS2', config=config):
        config = wandb.config
        args = SimpleNamespace(**defaults)

        def cfg(key):
            # wandb value if present, otherwise the default above.
            return get_config_value(config, args, key)

        input_size = cfg('input_size')
        batch_size = cfg('batch_size')

        dataset_train_val = Dataset(
            data_dir=cfg('dataset_train'),
            input_size=input_size,
            data_augmentation=cfg('data_augmentation')
        )
        train_subset, val_subset = dataset_train_val.stratified_split(val_ratio=0.2)

        # Drop the last batch only when it would hold a single sample:
        # BatchNorm1d rejects batches of size 1 in training mode.
        train_loader = torch.utils.data.DataLoader(
            train_subset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=len(train_subset) % batch_size == 1
        )
        val_loader = torch.utils.data.DataLoader(val_subset, batch_size=batch_size, shuffle=False)

        num_filters = cfg('number_of_filters')
        wandb.run.name = (
            f"lr_{cfg('learning_rate')}_1filters_{num_filters[0]}_5filters_{num_filters[-1]}"
            f"_dense_{cfg('dense_layer')}_activation_{cfg('activation_fn')}"
        )

        model = CNN_model(
            input_size=input_size,
            output_size=len(dataset_train_val.get_classes()),
            num_filters=num_filters,
            kernel_sizes=cfg('kernel_sizes'),
            pool_kernels=cfg('pool_kernels'),
            paddings=cfg('paddings'),
            conv_strides=cfg('conv_strides'),
            dense_layer=cfg('dense_layer'),
            activation_fn=cfg('activation_fn'),
            use_softmax=cfg('use_softmax'),
            batch_norm=cfg('batch_norm'),
            dropout_rate=cfg('dropout_rate')
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)

        model = train_CNN_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            learning_rate=cfg('learning_rate'),
            epochs=cfg('epochs'),
            device=device,
            patience=cfg('patience')
        )

        # The test set is never augmented: random flips/rotations/jitter would
        # make the reported test metrics noisy and not comparable across runs.
        dataset_test = Dataset(data_dir=cfg('dataset_test'), input_size=input_size, data_augmentation=False)
        test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

        test_CNN_model(model=model, test_loader=test_loader, device=device)

        # Release references first, then collect, then hand cached GPU blocks
        # back, so memory is actually freed between sweep runs.
        del model, train_loader, val_loader, test_loader
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
# Single run with the default hyperparameters from train_model (no sweep); results are logged to W&B.
train_model()

In [9]:
# W&B sweep definition: Bayesian search maximising "validation_accuracy",
# the metric train_CNN_model logs after every epoch. Hyperparameters not listed
# here (batch_size, epochs, patience, use_softmax, input_size, dataset paths)
# fall back to the defaults inside train_model.
sweep_config = {
    'method': 'bayes',
    'metric': {
        'name': 'validation_accuracy',
        'goal': 'maximize'
    },
    'parameters': {
        'data_augmentation': {
            'values': [True, False]
        },
        # Output channels of each of the five convolutional blocks.
        'number_of_filters': {
            'values': [
                [256, 128, 64, 32, 16],
                [512, 256, 128, 64, 32],
                [32, 32, 32, 32, 32],
                [64, 64, 64, 64, 64],
                [16, 32, 64, 128, 256],
                [32, 64, 128, 256, 512],
                [64, 128, 256, 512, 1024]
            ]
        },
        # Convolution kernel size per block.
        'kernel_sizes': {
            'values': [
                [3, 3, 3, 3, 3],
                [5, 5, 3, 3, 3],
                [5, 5, 5, 5, 5]
            ]
        },
        # Max-pool kernel per block (CNN_model also uses it as the pool stride).
        'pool_kernels': {
            'values': [
                [2, 2, 2, 2, 2],
                [3, 3, 2, 2, 2]
            ]
        },
        # Padding and stride are fixed (single choice each).
        'paddings': {
            'values': [
                [1, 1, 1, 1, 1]
            ]
        },
        'conv_strides': {
            'values': [
                [1, 1, 1, 1, 1]
            ]
        },
        # Width of the hidden fully connected layer.
        'dense_layer': {
            'values': [256, 512, 1024]
        },
        'activation_fn': {
            'values': ["relu", "tanh", "selu", "gelu", "mish", "leakyrelu"]
        },
        # 0/1 flag: batch normalisation after each conv block and the hidden dense layer.
        'batch_norm': {
            'values': [0, 1]
        },
        # Used by both Dropout2d (conv blocks) and Dropout (dense head).
        'dropout_rate': {
            'distribution': 'uniform',
            'max': 0.75,
            'min': 0
            
        },
        # Adam learning rate, sampled log-uniformly.
        # NOTE(review): 0.1 is very large for Adam; runs near the top of this range may diverge.
        'learning_rate': {
            'distribution': 'log_uniform_values',
            'min': 0.0001,
            'max': 0.1
        }
    }
}

In [None]:
# Configuration for the sweep
# NOTE(review): entity/project are also hard-coded in train_model's wandb.init; keep them in sync.
entity = 'me21b138-indian-institute-of-technology-madras'  # Your wandb entity
project = 'AS2'  # Your wandb project
count = 100  # Number of runs to execute

# Initialize the sweep
# NOTE(review): opts into wandb's "core" backend; on recent wandb releases this may already be the default — confirm for the installed version.
wandb.require("core")
sweep_id = wandb.sweep(sweep_config, entity=entity, project=project)

# Start the sweep agent: calls train_model `count` times with hyperparameters proposed by the sweep.
wandb.agent(sweep_id, function=train_model, count=count)