In [1]:
import torch
import tqdm
import wandb

In [2]:
from src.models.CLIPArTT import NewCLIP, configure_model, Tent, clipartt_eval
from src.training.trainer import training_step, get_cost_function, get_optimizer
from src.data.dataset import get_data, base_novel_categories, split_dataset, CLASS_NAMES

In [3]:
# Run on the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Build NewCLIP (from src.models.CLIPArTT) and move it to the selected device.
tmp_model = NewCLIP().to(device)

In [4]:

# Swap the visual encoder for whatever configure_model returns for it.
# The optimizer and the Tent wrapper are built per configuration in the
# grid-search cell, so nothing else is constructed here.
tmp_model.clip.visual = configure_model(tmp_model.clip.visual)

# Load the three dataset splits using the model's own preprocessing transform.
train_set, val_set, test_set = get_data(transform=tmp_model.preprocess)

# Derive the base/novel class partition from the training split, then apply
# the same base-class list when splitting every dataset.
base_classes, novel_classes = base_novel_categories(train_set)
train_base, train_novel = split_dataset(train_set, base_classes)
val_base, _ = split_dataset(val_set, base_classes)  # second half of the validation split is discarded
test_base, test_novel = split_dataset(test_set, base_classes)

In [6]:
import copy
from itertools import product

# Hyperparameter grid: every combination of these values is evaluated once.
learning_rates = [0.001, 0.01, 0.1]
momentums = [0.9, 0.99]
weight_decays = [0.0001, 0.001]
steps_list = [3, 5, 10]

best_accuracy = 0
best_params = {}
results = []  # One record per configuration: its hyperparameters plus the accuracy

# Snapshot the weights before any adaptation happens. Every configuration
# below adapts the same `tmp_model` object, so without an explicit restore a
# configuration could start from weights left behind by the previous one,
# making the results depend on the order of the grid.
# NOTE(review): Tent(episodic=True) presumably resets between batches, but its
# internals are not visible here; the explicit restore removes the dependency
# either way. Re-run the model-creation cell before re-running this one so the
# snapshot is taken from un-adapted weights.
initial_state = copy.deepcopy(tmp_model.state_dict())

# Perform grid search
for lr, momentum, wd, steps in product(learning_rates, momentums, weight_decays, steps_list):
    # Start each configuration from identical weights.
    tmp_model.load_state_dict(initial_state)

    optimizer = get_optimizer(tmp_model.clip.visual, learning_rate=lr, weight_decay=wd, momentum=momentum)
    model = Tent(model=tmp_model, optimizer=optimizer, steps=steps, episodic=True)

    accuracy = clipartt_eval(model, val_base, base_classes, batch_size=16, device=device)

    # Save the result (key order matches the original records).
    params = {"learning_rate": lr, "momentum": momentum, "weight_decay": wd, "steps": steps}
    results.append({**params, "accuracy": accuracy})

    # Strict comparison: on ties the earliest configuration in grid order wins.
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        best_params = params

# Leave the shared model in its pre-search state for any later cells.
tmp_model.load_state_dict(initial_state)

print("Best Accuracy:", best_accuracy)
print("Best Parameters:", best_params)
print("All Results:", results)

100%|██████████| 32/32 [00:20<00:00,  1.59it/s]
  0%|          | 0/32 [00:00<?, ?it/s]
100%|██████████| 32/32 [00:31<00:00,  1.01it/s]
  0%|          | 0/32 [00:00<?, ?it/s]
100%|██████████| 32/32 [01:08<00:00,  2.14s/it]
  0%|          | 0/32 [00:00<?, ?it/s]
100%|██████████| 32/32 [00:25<00:00,  1.27it/s]
  0%|          | 0/32 [00:00<?, ?it/s]
100%|██████████| 32/32 [00:40<00:00,  1.26s/it]
  0%|          | 0/32 [00:00<?, ?it/s]
100%|██████████| 32/32 [01:26<00:00,  2.69s/it]
  0%|          | 0/32 [00:00<?, ?it/s]
100%|██████████| 32/32 [00:28<00:00,  1.12it/s]
  0%|          | 0/32 [00:00<?, ?it/s]
100%|██████████| 32/32 [00:46<00:00,  1.45s/it]
  0%|          | 0/32 [00:00<?, ?it/s]
100%|██████████| 32/32 [01:32<00:00,  2.90s/it]
  0%|          | 0/32 [00:00<?, ?it/s]
100%|██████████| 32/32 [00:29<00:00,  1.10it/s]
  0%|          | 0/32 [00:00<?, ?it/s]
100%|██████████| 32/32 [00:44<00:00,  1.40s/it]
  0%|          | 0/32 [00:00<?, ?it/s]
100%|██████████| 32/32 [01:30<00:00,  2.84s

Best Accuracy: 0.7254901960784313
Best Parameters: {'learning_rate': 0.001, 'momentum': 0.9, 'weight_decay': 0.001, 'steps': 5}
All Results: [{'learning_rate': 0.001, 'momentum': 0.9, 'weight_decay': 0.0001, 'steps': 3, 'accuracy': 0.7019607843137254}, {'learning_rate': 0.001, 'momentum': 0.9, 'weight_decay': 0.0001, 'steps': 5, 'accuracy': 0.7098039215686275}, {'learning_rate': 0.001, 'momentum': 0.9, 'weight_decay': 0.0001, 'steps': 10, 'accuracy': 0.7058823529411765}, {'learning_rate': 0.001, 'momentum': 0.9, 'weight_decay': 0.001, 'steps': 3, 'accuracy': 0.7235294117647059}, {'learning_rate': 0.001, 'momentum': 0.9, 'weight_decay': 0.001, 'steps': 5, 'accuracy': 0.7254901960784313}, {'learning_rate': 0.001, 'momentum': 0.9, 'weight_decay': 0.001, 'steps': 10, 'accuracy': 0.7156862745098039}, {'learning_rate': 0.001, 'momentum': 0.99, 'weight_decay': 0.0001, 'steps': 3, 'accuracy': 0.7098039215686275}, {'learning_rate': 0.001, 'momentum': 0.99, 'weight_decay': 0.0001, 'steps': 5, 'a




In [8]:
import json

# Persist every grid-search record to disk as JSON with a 4-space indent,
# overwriting any existing results.json in the working directory.
with open("results.json", "w") as results_file:
    json.dump(results, results_file, indent=4)