# Temp Notes

## Running Inference on Raw TIMM

In [None]:
import timm
import torch
from torchvision import transforms
from PIL import Image

# Settings copied from the model's JSON config. The top-level keys select the
# architecture and classifier size; the nested "pretrained_cfg" holds the input
# size, normalisation statistics, crop settings and stem/head layer names.
config = dict(
    architecture="maxvit_rmlp_pico_rw_256",
    num_classes=4,
    num_features=256,
    global_pool="avg",
    pretrained_cfg=dict(
        tag="sw_in1k",
        custom_load=False,
        input_size=[3, 256, 256],
        fixed_input_size=True,
        interpolation="bicubic",
        crop_pct=0.95,
        crop_mode="center",
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
        num_classes=1000,
        pool_size=[8, 8],
        first_conv="stem.conv1",
        classifier="head.fc",
    ),
)

# Location of the checkpoint file holding the weights to load
checkpoint_path = "checkpoint.bin"

# Build the architecture named in the config. pretrained=False skips timm's own
# pretrained weights; the classifier is sized to config["num_classes"] and the
# pooling mode is taken from config["global_pool"].
model = timm.create_model(
    config["architecture"],
    global_pool=config["global_pool"],
    num_classes=config["num_classes"],
    pretrained=False,
)

# Load the checkpoint onto the CPU (adjust map_location as needed).
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source, or pass weights_only=True
# if the installed torch version supports it.
checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))

# Structured checkpoints wrap the weights under "state_dict"; otherwise the
# loaded object is treated as the state dict itself.
state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

# Strip a leading 'module.' (added by DataParallel-style wrappers). Only the
# prefix is removed: a blanket str.replace would also rewrite keys that merely
# contain "module." somewhere in the middle.
prefix = "module."
state_dict = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in state_dict.items()
}

# strict=False tolerates key mismatches, so report them explicitly; otherwise
# a checkpoint that does not match the model would load "successfully" while
# leaving layers randomly initialised. Use strict=True to fail hard instead.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print("Warning: keys missing from checkpoint:", load_result.missing_keys)
if load_result.unexpected_keys:
    print("Warning: unexpected keys in checkpoint:", load_result.unexpected_keys)

# Example preprocessing pipeline, driven by the values in config["pretrained_cfg"]
pretrained_cfg = config["pretrained_cfg"]
input_size = pretrained_cfg["input_size"][1:]  # (256, 256)
mean = pretrained_cfg["mean"]
std = pretrained_cfg["std"]
crop_pct = pretrained_cfg["crop_pct"]

# crop_pct is the fraction of the resized image that the center crop keeps, so
# the image is resized to input_size / crop_pct and then cropped back to
# input_size. Cropping to input_size * crop_pct instead would feed 243x243
# tensors to a model whose config declares a fixed 256x256 input.
resize_size = [int(side / crop_pct) for side in input_size]  # (269, 269)

transform = transforms.Compose([
    transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Example image (replace with your own image file path)
image_path = "example.jpg"
# Open the file in a context manager so the underlying file handle is closed
# as soon as the RGB copy has been made, instead of lingering until GC.
with Image.open(image_path) as raw_image:
    image = raw_image.convert("RGB")
image_tensor = transform(image).unsqueeze(0)  # Add batch dimension

# Inference
model.eval()  # Evaluation mode: disables dropout / uses running norm stats
with torch.no_grad():  # No gradients are needed for a forward-only pass
    outputs = model(image_tensor)  # Forward pass -> raw logits
    predictions = torch.softmax(outputs, dim=1)  # Convert logits to probabilities

print("Predictions:", predictions)


### Key Additions

**Loading the checkpoint**

- `torch.load(checkpoint_path)` reads the checkpoint file.
- The checkpoint may hold the weights directly or wrap them under a `state_dict` key; adjust the lookup to match your checkpoint's structure.

**Prefix handling**

- Check whether the `state_dict` keys carry a `module.` prefix (common when the model was trained with `DataParallel`) and strip it if so.

**Loading into the model**

- `model.load_state_dict(state_dict, strict=False)` loads the weights while tolerating missing or unexpected keys. With `strict=True`, it instead enforces an exact match between the model's parameters and the checkpoint.

**Skipping pretrained weights**

- `pretrained=False` ensures no weights from timm are loaded, since your checkpoint provides the weights.
### Things to Check

**Checkpoint compatibility**

- Ensure your checkpoint was trained on the same architecture as the one defined in the config.
- Verify that `num_classes` matches the number of classes the checkpoint was trained with.

**Adjusting the device**

- For GPU inference, move the model and the input tensor to the GPU:

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
image_tensor = image_tensor.to(device)
```

# Trackable Hyperparameter Tuning by Exposing Ray Calls

In [None]:
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.bayesopt import BayesOptSearch
from autogluon.multimodal import MultiModalPredictor

# ASHA Scheduler: stops under-performing trials early.
# metric/mode are intentionally not set here. The tune.run() call in this file
# already passes metric="f1_macro" and mode="max" and forwards them to the
# scheduler; Ray Tune raises a ValueError when both the scheduler and
# tune.run() specify them.
asha_scheduler = ASHAScheduler(
    grace_period=1,     # Minimum number of reported iterations before a trial may be stopped
    max_t=20            # Maximum number of reported iterations per trial
)

# Bayesian-optimisation search algorithm: proposes the next configuration to
# try, aiming for the highest "f1_macro".
# NOTE(review): Ray's BayesOptSearch handles only continuous (float) search
# domains -- confirm it accepts the randint/choice entries in search_space.
bayesopt_search = BayesOptSearch(mode="max", metric="f1_macro")

# Define the search space. Entries built with tune.* are sampled per trial;
# plain values are constants passed to every trial unchanged.
search_space = {
    "optimization.learning_rate": tune.loguniform(0.00001, 0.001),
    "optimization.max_epochs": tune.randint(5, 20),  # upper bound is exclusive
    "env.batch_size": tune.choice([8, 16, 32, 64, 128, 256]),
    "optimization.loss_function": "focal_loss",
    # Predefined class weights, passed as-is. Wrapping the list in another
    # list would produce a nested [[w0, w1, ...]] instead of one weight per
    # class.
    # NOTE(review): assumes class_weights_list is a flat list of per-class
    # floats defined earlier -- confirm.
    "optimization.focal_loss.alpha": class_weights_list,
    "optimization.focal_loss.gamma": tune.uniform(1, 3),
    "optimization.focal_loss.reduction": "sum",
    "model.timm_image.checkpoint_name": "efficientnet_b2",
    "optimization.optim_type": "adamw",
    "optimization.top_k_average_method": "best",
}

# Train function remains the same
def train_multimodal(config):
    """Fit one MultiModalPredictor for a trial and report its F1-macro to Ray Tune.

    Args:
        config: Hyperparameter dict for this trial; forwarded unchanged to
            ``MultiModalPredictor.fit`` as ``hyperparameters``.
    """
    fit_options = {
        "train_data": "path_to_train.csv",
        "hyperparameters": config,
        "time_limit": 600,
    }

    trial_predictor = MultiModalPredictor(
        eval_metric="f1_macro",
        problem_type="classification",
        label="label_column",
    )
    trial_predictor.fit(**fit_options)

    # Score on the validation file and hand the metric back to Ray Tune.
    scores = trial_predictor.evaluate("path_to_val.csv")
    tune.report(f1_macro=scores["f1_macro"])

# Launch the Ray Tune experiment: 10 trials of train_multimodal, sampled from
# search_space by the search algorithm and managed by the scheduler, each
# trial reserving 4 CPUs and 1 GPU (adjust to your hardware). The objective
# is to maximise "f1_macro".
analysis = tune.run(
    train_multimodal,
    config=search_space,
    search_alg=bayesopt_search,
    scheduler=asha_scheduler,
    num_samples=10,
    resources_per_trial={"cpu": 4, "gpu": 1},
    mode="max",
    metric="f1_macro",
)

# Show the configuration of the best trial under the metric/mode given above
best_hyperparameters = analysis.best_config
print("Best hyperparameters found:", best_hyperparameters)
