## Efficiency improvement to Vision Mamba
This experiment trains the model on all patches extracted from the full image set. It then trains again after randomly dropping a fraction of each image's patches.

The structure of this code and its Mamba model are inspired by L. Zhu, B. Liao, Q. Zhang, X. Wang, W. Liu, and X. Wang, "Vision Mamba: Efficient Visual
Representation Learning with Bidirectional State Space Model," Computer Vision and Pattern
Recognition, Jan. 2024. [Online]. Available: https://github.com/hustvl/Vim

### Download the Cats and Dogs data from Kaggle

In [None]:
# import kagglehub

# # Download the Cats vs Dogs dataset
# path = kagglehub.dataset_download("abhinavnayak/catsvdogs-transformed")

# print("Path to dataset files:", path)

                  ###################################### Code Starts here ##############################################

### Load libraries

In [None]:
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
import utils
import sys
import os
import zipfile
import pandas as pd
import models_mamba

from pathlib import Path
from timm.data import Mixup
from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma

# from datasets import build_dataset
from engine import train_one_epoch, evaluate
from losses import DistillationLoss
from samplers import RASampler
from augment import new_data_aug_generator
from contextlib import suppress

from torchvision import transforms
from fvcore.nn import FlopCountAnalysis


# log about
import mlflow



# Module-level default compute device (CPU). Note: main() creates its own device
# object and does not read this global.
device = torch.device("cpu")

In [None]:
def get_args_parser():
    """Build the (parent) argument parser for the Vision Mamba patch experiment.

    The parser is created with ``add_help=False`` so that it can be passed as a
    ``parents=[...]`` entry of another ``ArgumentParser`` (as the ``__main__``
    block does).

    When running inside a Jupyter/IPython kernel, ``sys.argv`` contains kernel
    flags that argparse would reject, so it is truncated to the script name.
    Outside a kernel, ``sys.argv`` is left untouched so that options given on a
    real command line are honoured.

    Returns:
        argparse.ArgumentParser: parser defining all training/evaluation options.
    """
    # Only strip kernel arguments when actually running under ipykernel. Doing it
    # unconditionally silently discards every option passed on a real command line,
    # because parse_args() reads sys.argv after this function has returned.
    if 'ipykernel' in sys.modules:
        sys.argv = sys.argv[:1]  # Keep only the script name (the first argument)

    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--epochs', default=20, type=int)
    parser.add_argument('--bce-loss', action='store_true')
    parser.add_argument('--unscale-lr', action='store_true')

    # Model parameters
    parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
                        help='Color jitter factor (default: 0.3)')
    # The original help text had a stray `" + \` continuation embedded inside the
    # string literal, which leaked quotes and a run of spaces into --help output.
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    parser.add_argument('--train-mode', action='store_true')
    parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
    parser.set_defaults(train_mode=True)

    parser.add_argument('--ThreeAugment', action='store_true')  # 3augment

    parser.add_argument('--src', action='store_true')  # simple random crop

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Distillation parameters
    parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
                        help='Name of teacher model to train (default: "regnety_160")')
    parser.add_argument('--teacher-path', type=str, default='')
    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
    parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
    parser.add_argument('--distillation-tau', default=1.0, type=float, help="")

    # * Cosub params
    parser.add_argument('--cosub', action='store_true')

    # * Finetuning params
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
    parser.add_argument('--attn-only', action='store_true')

    # Dataset parameters. There is no --data-path / --data-set option: the image
    # folder is hard-coded where the dataset is loaded.
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    # Defaults to CPU; pass e.g. --device cuda to use a GPU.
    parser.add_argument('--device', default='cpu',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--distributed', action='store_true', default=False, help='Enabling distributed training')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # amp about
    parser.add_argument('--if_amp', action='store_true')
    parser.add_argument('--no_amp', action='store_false', dest='if_amp')
    parser.set_defaults(if_amp=False)

    # if continue with inf
    parser.add_argument('--if_continue_inf', action='store_true')
    parser.add_argument('--no_continue_inf', action='store_false', dest='if_continue_inf')
    parser.set_defaults(if_continue_inf=False)

    # if use nan to num
    parser.add_argument('--if_nan2num', action='store_true')
    parser.add_argument('--no_nan2num', action='store_false', dest='if_nan2num')
    parser.set_defaults(if_nan2num=False)

    parser.add_argument('--nb_classes', default=2, type=int, help='Number of classes (cats and dogs)')

    parser.add_argument('--local-rank', default=0, type=int)
    return parser

### Dataset path

In [None]:
# Folder that holds every Cats-vs-Dogs image (absolute path on the author's machine).
image_path = "C:/Users/mahagam3/Documents/ECE course/Project B/Mamba/Vision_mamba_code/cats_dogs"

# Full paths of the .jpg / .png files found directly inside image_path
# (non-recursive; the extension match is case-sensitive).
image_files = [
    os.path.join(image_path, entry_name)
    for entry_name in os.listdir(image_path)
    if entry_name.endswith(('.jpg', '.png'))
]

### Load images, extract patches, and randomly drop a fraction of them

In [None]:
from PIL import Image

def load_images_and_extract_patches(image_folder, patch_size=16, drop_fraction=0.50): # change the drop_fraction to 0.0 to get all patches
    """Cut every image in *image_folder* into square patches and randomly drop some of them.

    Each ``.jpg`` / ``.png`` file directly inside *image_folder* is converted to RGB
    and tiled into a non-overlapping grid of ``patch_size`` x ``patch_size`` patches
    in row-major order. Pixels beyond the last full tile on the right or bottom edge
    are ignored. A random ``drop_fraction`` of each image's patches is then discarded,
    using NumPy's global RNG.

    Args:
        image_folder: Directory to scan (non-recursively) for images.
        patch_size: Side length of each square patch, in pixels. Must be positive.
        drop_fraction: Fraction in [0, 1] of each image's patches to discard.
            ``0.0`` keeps every patch.

    Returns:
        tuple: ``(patches, labels)``.
            ``patches`` is a uint8 array of shape ``(N, patch_size, patch_size, 3)``.
            ``labels`` is an array with the label of each patch's source image.

    Raises:
        ValueError: if ``patch_size`` is not positive or ``drop_fraction`` is
            outside [0, 1].
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    if not 0.0 <= drop_fraction <= 1.0:
        raise ValueError(f"drop_fraction must be in [0, 1], got {drop_fraction}")

    patch_arrays = []
    labels = []

    # sorted() makes the output order reproducible; os.listdir order is arbitrary.
    for image_file in sorted(os.listdir(image_folder)):
        if not image_file.endswith(('.jpg', '.png')):
            continue
        image_path = os.path.join(image_folder, image_file)

        # The context manager closes the file handle. convert('RGB') turns
        # grayscale / palette / RGBA images into 3-channel images.
        with Image.open(image_path) as image:
            image_data = np.array(image.convert('RGB'))

        # Number of *full* tiles per axis. The previous range(0, h - patch_size, ...)
        # missed the last row/column whenever h or w was an exact multiple of patch_size.
        h, w, _ = image_data.shape
        rows, cols = h // patch_size, w // patch_size
        if rows == 0 or cols == 0:
            continue  # image smaller than a single patch

        # Crop to whole tiles, then reshape into tiles. The result is ordered
        # row-major (top-left tile first, left-to-right, then top-to-bottom).
        image_patches = (
            image_data[:rows * patch_size, :cols * patch_size]
            .reshape(rows, patch_size, cols, patch_size, 3)
            .swapaxes(1, 2)
            .reshape(-1, patch_size, patch_size, 3)
        )

        # Keep a random (1 - drop_fraction) subset. Slicing with num_keep, rather than
        # with -num_drop, matters when num_drop == 0: indices[:-0] would be empty.
        num_patches = len(image_patches)
        num_keep = num_patches - int(num_patches * drop_fraction)
        keep = np.random.permutation(num_patches)[:num_keep]
        image_patches = image_patches[keep]
        patch_arrays.append(image_patches)

        # Label = filename text before the first whitespace (e.g. 'cat 1.jpg' -> 'cat').
        # NOTE(review): for names without a space (e.g. 'cat.1.jpg'), the whole
        # filename becomes the label. Confirm the dataset's naming scheme.
        label = image_file.split()[0]
        labels.extend([label] * len(image_patches))

    if patch_arrays:
        patches = np.concatenate(patch_arrays)
    else:
        patches = np.empty((0, patch_size, patch_size, 3), dtype=np.uint8)
    labels = np.array(labels)

    return patches, labels

### Build the dataset and DataLoader for training

In [None]:
def build_dataset(is_train=True, patch_size=16, batch_size=64, patches=None, labels=None, drop_fraction=0.95, num_workers=4): # adjust drop_fraction as needed
    """Wrap pre-extracted patches and their labels in a ``DataLoader``.

    Args:
        is_train: If True, return a shuffled training loader that drops the last
            incomplete batch. Otherwise, return an ordered evaluation loader with
            half the batch size that keeps every sample.
        patch_size: Unused. Kept for backward compatibility.
        batch_size: Training batch size. Evaluation uses ``max(1, batch_size // 2)``.
        patches: Sequence or array of equally-shaped patches (required).
        labels: One label per patch, strings or numbers (required). Labels are
            encoded to ids ``0..K-1`` in sorted order of the distinct values,
            matching scikit-learn's ``LabelEncoder``.
        drop_fraction: Unused here, since patches are dropped when they are
            extracted. Kept for backward compatibility.
        num_workers: Number of DataLoader worker processes (default 4).

    Returns:
        torch.utils.data.DataLoader: yields ``(inputs, targets)``.
            ``inputs`` are float32 tensors of shape ``(B, 1, *patch.shape)``.
            ``targets`` are int64 class ids.

    Raises:
        ValueError: if patches/labels are missing or empty, or their lengths differ.
    """
    # Local import so the function does not depend on a later notebook cell.
    from torch.utils.data import DataLoader, TensorDataset

    if patches is None or labels is None:
        raise ValueError("build_dataset requires both patches and labels")
    patches = np.asarray(patches)
    labels = np.asarray(labels)
    if len(patches) == 0:
        raise ValueError("build_dataset received no patches")
    if len(patches) != len(labels):
        raise ValueError(f"got {len(patches)} patches but {len(labels)} labels")

    # (N, 1, *patch.shape) float32. This is the same layout as stacking each
    # patch.unsqueeze(0), but built in a single conversion.
    patches_tensor = torch.as_tensor(patches, dtype=torch.float32).unsqueeze(1)

    # Equivalent to LabelEncoder().fit_transform(labels): the sorted unique values
    # map to 0..K-1. The file never imported scikit-learn's LabelEncoder.
    _, labels_encoded = np.unique(labels, return_inverse=True)
    labels_tensor = torch.as_tensor(labels_encoded.reshape(-1), dtype=torch.long)

    dataset = TensorDataset(patches_tensor, labels_tensor)

    if is_train:
        return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, drop_last=True)
    # Evaluation: keep the final partial batch so every sample is scored, and never
    # let the halved batch size reach 0 (DataLoader rejects batch_size=0).
    return DataLoader(dataset, batch_size=max(1, batch_size // 2), shuffle=False,
                      num_workers=num_workers, drop_last=False)


In [None]:
# Function to calculate the FLOPs
def calculate_flops(model, input_tensor):
    """Profile one forward pass of *model* on *input_tensor* with ``thop``.

    Before profiling, any ``total_ops`` / ``total_params`` attributes already present
    on the model's modules are deleted. Presumably this is so that a model that was
    profiled before can be profiled again (thop attaches these counters itself).

    NOTE(review): ``thop.profile`` is commonly documented as returning MACs
    (multiply-accumulates), not FLOPs. Confirm which unit is intended before
    comparing with FLOP numbers from other tools.

    Args:
        model: The torch module to profile.
        input_tensor: Passed as the single positional input to ``model``.

    Returns:
        The first value returned by ``thop.profile`` (the operation count).
    """
    from thop import profile

    stale_counters = ('total_ops', 'total_params')
    for submodule in model.modules():
        for counter_name in stale_counters:
            if hasattr(submodule, counter_name):
                delattr(submodule, counter_name)

    op_count, _param_count = profile(model, inputs=(input_tensor,))
    return op_count

In [None]:
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

# Main function
def main(args):
    """Train a model on (optionally dropped) image patches and report estimated training FLOPs.

    Each patch from ``load_images_and_extract_patches`` is treated as a standalone
    sample. It is upsampled to ``args.input_size`` x ``args.input_size`` and labelled
    with the class of its source image.

    Args:
        args: Namespace from ``get_args_parser()``. This function reads ``device``,
            ``batch_size``, ``num_workers``, ``pin_mem``, ``model``, ``drop``,
            ``drop_path``, ``input_size``, ``nb_classes``, ``smoothing``,
            ``start_epoch`` and ``epochs``. timm's ``create_optimizer`` and
            ``create_scheduler`` read further fields.

    Raises:
        RuntimeError: if no patches could be extracted from the image folder.
    """
    # Defaults to CPU via the --device option.
    device = torch.device(args.device)

    # Load image file paths and patches
    image_folder = "C:/Users/mahagam3/Documents/ECE course/Project B/Mamba/Vision_mamba_code/cats_dogs"

    # Load patches and labels directly, applying the loader's default drop_fraction
    print("Loading dataset...")
    patches, labels = load_images_and_extract_patches(image_folder)
    if len(patches) == 0:
        raise RuntimeError(f"No patches could be extracted from {image_folder!r}")

    print(f"Dataset loaded with {len(patches)} patches!")

    # (N, 1, patch, patch, 3) float32, converted in one shot (no resizing here).
    patches_tensor = torch.as_tensor(np.asarray(patches), dtype=torch.float32).unsqueeze(1)
    print(patches_tensor.shape)

    # Map string labels to integer ids. sorted() makes the mapping deterministic:
    # iteration order of a set of str depends on PYTHONHASHSEED, so cat/dog could
    # otherwise swap ids between runs.
    if isinstance(labels[0], str):
        label_map = {label: idx for idx, label in enumerate(sorted(set(labels)))}
        labels = [label_map[label] for label in labels]

    labels_tensor = torch.tensor(labels, dtype=torch.long)

    dataset_train = TensorDataset(patches_tensor, labels_tensor)

    data_loader_train = DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        # Patches are grouped by source image (and therefore class), so shuffle to
        # avoid single-class batches.
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    # Binary problems use a single logit with BCE; otherwise one logit per class.
    binary = args.nb_classes == 2

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=1 if binary else args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        img_size=args.input_size,
    )
    model.to(device)

    # Cost of one forward pass on a single input image. The per-batch cost below is
    # this value times the actual batch size, which is how many upsampled patches
    # the model really processes per step.
    sample_input = torch.randn(1, 3, args.input_size, args.input_size, device=device)
    flops_per_sample = calculate_flops(model, sample_input)
    print(f'FLOPs for Model (per sample): {flops_per_sample:.2f} FLOPs')

    # Optimizer and scheduler
    optimizer = create_optimizer(args, model)
    lr_scheduler, _ = create_scheduler(args, optimizer)

    # Loss function (the parser defines --smoothing; there is no --label-smoothing).
    if binary:
        criterion = torch.nn.BCEWithLogitsLoss()
    elif args.smoothing > 0.0:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    total_training_flops = 0  # accumulated across all epochs

    # Ensure dropout / drop-path are active regardless of what profiling did.
    model.train()
    for epoch in range(args.start_epoch, args.epochs):
        epoch_flops = 0

        for inputs, targets in data_loader_train:
            optimizer.zero_grad()

            # (B, 1, P, P, 3) -> (B, 3, P, P), then upsample each patch to the
            # model's expected input resolution.
            inputs = inputs.squeeze(1).permute(0, 3, 1, 2)
            inputs = F.interpolate(inputs, size=(args.input_size, args.input_size),
                                   mode='bilinear', align_corners=False)
            inputs = inputs.to(device)
            targets = targets.to(device)
            if binary:
                # BCEWithLogitsLoss expects float targets shaped like the (B, 1) logits.
                targets = targets.float().unsqueeze(1)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            epoch_flops += flops_per_sample * inputs.shape[0]

        # Update learning rate
        lr_scheduler.step(epoch)

        print(f"Epoch {epoch + 1}: Total FLOPs for this epoch: {epoch_flops / 1e9:.2f} GFLOPs")
        total_training_flops += epoch_flops

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f'Training time: {total_time_str}')

    # Convert to GFLOPs for easier interpretation
    print(f"Total training FLOPs: {total_training_flops / 1e9:.2f} GFLOPs")

if __name__ == '__main__':
    # Top-level parser: inherits every option from get_args_parser() (which was
    # built with add_help=False) and lets later definitions win on conflicts.
    parser = argparse.ArgumentParser(
        'Training and evaluation script',
        parents=[get_args_parser()],
        conflict_handler='resolve',
    )
    args = parser.parse_args()

    # Create the output directory tree up front when one was requested;
    # an already existing directory is fine.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    main(args)

           ########################################### End of the code ################################################