In [None]:
import torch
import torchvision
from torchvision import models
from tqdm import tqdm
import wandb
import gc

In [None]:
!wandb login

In [None]:
class CNN_model(torch.nn.Module):
    """
    Image classifier built on an ImageNet-pretrained ResNet50 backbone.

    All backbone parameters are frozen except those in the last
    ``trainable_layers`` top-level children of the ResNet. In torchvision's
    ResNet these children are, in order: conv1, bn1, relu, maxpool,
    layer1-layer4, avgpool, fc. The final fully connected layer is always
    replaced by a fresh ``Linear(in_features, num_classes)`` head, which is
    trainable regardless of ``trainable_layers``.

    Args:
        num_classes: number of outputs of the new classification head.
        trainable_layers: how many top-level children, counted from the end,
            to unfreeze. Values larger than the number of children unfreeze
            the whole backbone. Values <= 0 leave only the new head trainable.
    """
    def __init__(self, num_classes=10, trainable_layers=1):
        super().__init__()
        self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

        # Freeze every pretrained parameter first.
        for param in self.model.parameters():
            param.requires_grad = False

        # Unfreeze the last `trainable_layers` children. Clamp explicitly:
        # `children[-0:]` would select *every* child and a negative count would
        # select most of them, silently unfreezing the whole backbone.
        children = list(self.model.children())
        n_unfrozen = max(0, min(trainable_layers, len(children)))
        if n_unfrozen:
            for layer in children[-n_unfrozen:]:
                for param in layer.parameters():
                    param.requires_grad = True

        # Replace the ImageNet head with one sized for this task. A new Linear
        # layer's parameters have requires_grad=True by default.
        num_features = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return the class logits produced by the wrapped ResNet50 for batch *x*."""
        return self.model(x)

class Dataset(torch.utils.data.Dataset):
    """
    Image-folder dataset with optional training-time augmentation.

    Images are loaded with ``torchvision.datasets.ImageFolder``. They are
    resized to ``input_size``, converted to tensors and normalized with fixed
    per-channel mean/std values. NOTE(review): these values are presumably
    statistics of the training images; confirm.

    When ``data_augmentation`` is True, a random horizontal flip, a random
    rotation of up to 20 degrees and colour jitter are applied after resizing.

    Args:
        data_dir: root directory handed to ImageFolder.
        input_size: an int (square output) or a ``(height, width)`` tuple.
        data_augmentation: whether to apply the random augmentations.
    """
    def __init__(self, data_dir, input_size=(224,224), data_augmentation=False):
        super().__init__()
        # Kept so stratified_split can build a matching non-augmented view.
        self.input_size = input_size
        self.data_augmentation = data_augmentation
        self.data = torchvision.datasets.ImageFolder(
            root=data_dir,
            transform=self._build_transform(input_size, data_augmentation),
        )

    @staticmethod
    def _build_transform(input_size, augment):
        """Return the preprocessing pipeline, including random augmentations iff *augment*."""
        size = input_size if isinstance(input_size, tuple) else (input_size, input_size)
        steps = [torchvision.transforms.Resize(size)]
        if augment:
            steps += [
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.RandomRotation(20),
                torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            ]
        steps += [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.4716, 0.4602, 0.3899], std=[0.2382, 0.2273, 0.2361]),
        ]
        return torchvision.transforms.Compose(steps)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def get_classes(self):
        """Return the class names discovered by ImageFolder."""
        return self.data.classes

    def stratified_split(self, val_ratio=0.2, seed=42):
        """
        Split into train/validation subsets that preserve class proportions.

        For each class, ``int(n_class * val_ratio)`` samples chosen by a
        ``torch.Generator`` seeded with *seed* go to validation. The rest go
        to training, so the split is reproducible for a given seed.

        If this dataset augments, the validation subset is drawn from a second
        ImageFolder over the same root that uses the non-augmented transform.
        This keeps validation metrics (and early stopping) deterministic.

        Returns:
            ``(train_subset, val_subset)`` as ``torch.utils.data.Subset`` objects.
        """
        # Group sample indices by their class label, in first-seen order.
        class_to_indices = {}
        for idx, (_, label) in enumerate(self.data.samples):
            class_to_indices.setdefault(label, []).append(idx)

        train_indices, val_indices = [], []
        generator = torch.Generator().manual_seed(seed)
        for indices in class_to_indices.values():
            n_val = int(len(indices) * val_ratio)
            order = torch.randperm(len(indices), generator=generator).tolist()
            val_indices.extend(indices[i] for i in order[:n_val])
            train_indices.extend(indices[i] for i in order[n_val:])

        if self.data_augmentation:
            # ImageFolder sorts class folders and file names, so a rescan of the
            # same root yields the same sample order and the indices line up.
            val_data = torchvision.datasets.ImageFolder(
                root=self.data.root,
                transform=self._build_transform(self.input_size, augment=False),
            )
        else:
            val_data = self.data

        train_subset = torch.utils.data.Subset(self.data, train_indices)
        val_subset = torch.utils.data.Subset(val_data, val_indices)
        return train_subset, val_subset

def test_CNN_model(model, test_loader, device, test_logging = True):
    """
    Evaluate *model* on every batch of *test_loader*.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    The reported loss is cross-entropy averaged over samples, not over
    batches, so a short final batch is not over-weighted.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        test_loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass.
        test_logging: if True, log ``test_accuracy`` and ``test_loss`` to wandb.

    Returns:
        ``(accuracy_percent, mean_loss)``.

    Raises:
        ValueError: if *test_loader* yields no samples.
    """
    model.eval()
    # Sum per-sample losses so the final division gives a true per-sample mean.
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss_sum += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("test_loader yielded no samples")

    accuracy = 100 * correct / total
    mean_loss = loss_sum / total

    if test_logging:
        wandb.log({
            "test_accuracy": accuracy,
            "test_loss": mean_loss
        })

    return accuracy, mean_loss


# The name matches pytest's default `test_*` pattern, but this is an evaluation
# helper whose positional parameters are not fixtures; keep pytest from collecting it.
test_CNN_model.__test__ = False

def train_CNN_model(model, train_loader, val_loader, learning_rate, epochs=10, device='cuda', patience=3):
    """
    Train *model* with Adam and cross-entropy, early-stopping on validation loss.

    Each epoch does the following:
    - Runs one pass over *train_loader*.
    - Evaluates on *val_loader* with ``test_CNN_model``. That call leaves the
      model in eval mode, hence ``model.train()`` at the start of every epoch.
    - Logs train and validation metrics to wandb.

    A validation loss more than 1e-3 below the best so far snapshots the
    weights and resets the patience counter. After ``patience`` consecutive
    epochs without such an improvement, training stops. The best snapshot,
    if any, is loaded back into *model* before returning.

    Args:
        model: classifier producing logits of shape (batch, num_classes).
            This function does not move it; callers must place it on *device*.
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: iterable of ``(images, labels)`` validation batches.
        learning_rate: Adam learning rate.
        epochs: maximum number of epochs.
        device: device each batch is moved to.
        patience: consecutive non-improving epochs tolerated before stopping.

    Returns:
        *model*, carrying the weights of the best validation epoch.

    Raises:
        ValueError: if *train_loader* yields no samples.
    """
    criterion = torch.nn.CrossEntropyLoss()
    # Only optimize parameters left trainable (e.g. by CNN_model's freezing).
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)

    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        loss_sum = 0.0
        correct = 0
        total = 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False):
            images = images.to(device)
            labels = labels.to(device)

            scores = model(images)
            loss = criterion(scores, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight each batch's mean loss by its size to get a per-sample epoch mean.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            correct += (scores.argmax(dim=1) == labels).sum().item()
            total += batch_size

        if total == 0:
            raise ValueError("train_loader yielded no samples")

        train_loss = loss_sum / total
        train_accuracy = 100 * correct / total

        val_accuracy, val_loss = test_CNN_model(model, val_loader, device, test_logging=False)

        wandb.log({
            "train_loss": train_loss,
            "train_accuracy": train_accuracy,
            "validation_accuracy": val_accuracy,
            "validation_loss": val_loss,
            "epoch": epoch+1
        })

        if val_loss < best_val_loss - 1e-3:
            best_val_loss = val_loss
            patience_counter = 0
            # state_dict() returns references to the live tensors, which the
            # optimizer updates in place. Clone them so the snapshot really
            # holds this epoch's weights.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            patience_counter += 1
            if patience_counter >= patience:
                break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model

In [None]:
def get_config_value(config, args, key, default=None):
    """
    Look up *key*, preferring *config* over *args*.

    Returns ``config.<key>`` if that attribute exists. Otherwise returns
    ``args.<key>``, and if *args* lacks it too, *default*.
    """
    fallback = getattr(args, key, default)
    return getattr(config, key, fallback)

def train_model(config=None):
    """
    Run one wandb-tracked training job; used as the sweep agent's function.

    Each setting is read from ``wandb.config`` when present and otherwise
    falls back to the defaults below. The steps are:
    1. Split the training folder 80/20 (stratified) into train/validation.
    2. Fine-tune a ResNet50-based ``CNN_model`` with early stopping.
    3. Evaluate the restored best model on train, validation and the
       held-out test folder.

    Args:
        config: optional initial config forwarded to ``wandb.init``.
    """
    import types

    defaults = {
        'batch_size': 64,
        'input_size': 224,
        'epochs': 7,
        'patience': 2,
        'data_augmentation': True,
        'trainable_layers': 3,
        'learning_rate': 0.00006239194756311145,
        'dataset_train': "/kaggle/input/da6401-as2-dataset/inaturalist_12K/train",  # Update this path
        'dataset_test': "/kaggle/input/da6401-as2-dataset/inaturalist_12K/val"     # Update this path
    }

    with wandb.init(entity='me21b138-indian-institute-of-technology-madras', project='AS2', config=config):
        config = wandb.config
        # Attribute-style view of the defaults, used as the fallback source.
        args = types.SimpleNamespace(**defaults)

        def cfg(key):
            """Return the wandb value for *key*, else its default."""
            return get_config_value(config, args, key)

        batch_size = cfg('batch_size')
        input_size = cfg('input_size')

        # Training folder, split into train/validation with per-class proportions kept.
        dataset_train_val = Dataset(
            data_dir=cfg('dataset_train'),
            input_size=input_size,
            data_augmentation=cfg('data_augmentation')
        )
        train_subset, val_subset = dataset_train_val.stratified_split(val_ratio=0.2)

        train_loader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_subset, batch_size=batch_size, shuffle=False)

        # Name the run after the swept hyperparameters so runs are easy to compare.
        wandb.run.name = f"lr_{cfg('learning_rate')}_trainable_{cfg('trainable_layers')}"

        model = CNN_model(
            num_classes=len(dataset_train_val.get_classes()),
            trainable_layers=cfg('trainable_layers')
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        print(device)

        model = train_CNN_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            learning_rate=cfg('learning_rate'),
            epochs=cfg('epochs'),
            device=device,
            patience=cfg('patience')
        )

        # The held-out test set is never augmented: random flips, rotations and
        # jitter would make test metrics noisy and non-reproducible.
        dataset_test = Dataset(data_dir=cfg('dataset_test'), input_size=input_size, data_augmentation=False)
        test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

        # Re-evaluate the restored best model. Train/val results go under their own
        # keys so they are not mixed into the test_* metrics. NOTE: train_loader
        # still applies augmentation when enabled.
        train_accuracy, train_loss = test_CNN_model(
            model=model,
            test_loader=train_loader,
            device=device,
            test_logging=False
        )
        print(f"Train Accuracy: {train_accuracy}, Train Loss: {train_loss}")

        val_accuracy, val_loss = test_CNN_model(
            model=model,
            test_loader=val_loader,
            device=device,
            test_logging=False
        )
        print(f"Validation Accuracy: {val_accuracy}, Validation Loss: {val_loss}")

        wandb.log({
            "final_train_accuracy": train_accuracy,
            "final_train_loss": train_loss,
            "final_validation_accuracy": val_accuracy,
            "final_validation_loss": val_loss
        })

        # Logs test_accuracy / test_loss to wandb.
        test_accuracy, test_loss = test_CNN_model(
            model=model,
            test_loader=test_loader,
            device=device
        )
        print(f"Test Accuracy: {test_accuracy}, Test Loss: {test_loss}")

        # Drop the model reference before collecting, so its GPU tensors become
        # unreferenced and the CUDA cache can actually be released.
        del model
        gc.collect()
        torch.cuda.empty_cache()

In [5]:
# Bayesian-optimization sweep over the fine-tuning hyperparameters. It maximizes
# the per-epoch `validation_accuracy` metric logged during training.
# NOTE(review): in torchvision's ResNet50 the last two children are avgpool
# (no parameters) and fc, so trainable_layers=1 and =2 appear to train the same
# parameters. Confirm whether 2 is intended.
sweep_config = dict(
    method='bayes',
    metric=dict(name='validation_accuracy', goal='maximize'),
    parameters=dict(
        data_augmentation=dict(values=[True, False]),
        batch_size=dict(values=[32, 64]),
        learning_rate=dict(distribution='log_uniform_values', min=0.00001, max=0.01),
        trainable_layers=dict(values=[1, 2, 3]),
    ),
)

In [None]:
# Launch the hyperparameter sweep: register sweep_config with W&B, then run
# train_model in this process for up to `count` trials.
entity = 'me21b138-indian-institute-of-technology-madras'  # W&B entity that owns the sweep
project = 'AS2'  # W&B project for the sweep and its runs
count = 100  # maximum number of runs this agent executes

wandb.require("core")
sweep_id = wandb.sweep(sweep_config, entity=entity, project=project)

# wandb.sweep returns a bare sweep ID. Give the agent the same entity/project,
# so it finds the sweep there rather than in the default entity/project.
wandb.agent(sweep_id, function=train_model, entity=entity, project=project, count=count)