In [None]:
import os

# CUDA reads CUDA_DEVICE_ORDER when it initialises. The project modules imported below may
# pull in torch and could initialise CUDA as an import side effect. So the variable is set
# before any of them load; otherwise the GPU index passed to configs.set_device() might
# follow the driver's default ("FASTEST_FIRST") ordering instead of PCI bus order.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"

import sys
# Make the project's bottleneck modules importable. The path is relative to the notebook's
# working directory.
sys.path.append("../bottlenecks")
import configs
# NOTE(review): star imports hide where names such as CBMConfig, BottleneckTrainer,
# prepared_dataloaders and Constants come from; consider explicit imports once confirmed.
from cbm import *
from data_utils import *
from trainer_utils import *
from graph_plot_tools import *
from utils import *
from metric_utils import *
from peft import LoraConfig, get_peft_model
from typing import List, Dict, Optional

In [None]:
# Reproducibility and hardware settings, kept together so they are easy to tune.
SEED = 42
GPU_INDEX = 2  # presumably a CUDA device index (PCI-bus order, see CUDA_DEVICE_ORDER in the imports cell) — confirm in configs.set_device

configs.set_seed(SEED)
device = configs.set_device(GPU_INDEX)

In [None]:
CONCEPTS_PATH = "../data/cub_filtered.txt"

# The file is read as one concept per line. str.split('\n') would leave a trailing "" entry
# when the file ends with a newline, and would keep blank lines and '\r' on CRLF files.
# Those junk entries would be counted in len(concepts) (num_concepts) and passed to the
# data pipeline as concepts. splitlines() plus the emptiness filter drops them.
with open(CONCEPTS_PATH, "r", encoding="utf-8") as f:
    concepts = [line.strip() for line in f.read().splitlines() if line.strip()]

assert concepts, f"no concepts found in {CONCEPTS_PATH}"

In [None]:
train_loader_preprocessed  = prepared_dataloaders(Constants.cub200_link,
                                                  concepts=concepts,
                                                  prep_loaders="train",
                                                  batch_size=128,
                                                  backbone_name=Constants.clip_large_link,
                                                 )

In [None]:
val_loader_preprocessed  = prepared_dataloaders(Constants.cub200_link,
                                                  concepts=concepts,
                                                  prep_loaders="val",
                                                  batch_size=128,
                                                  backbone_name=Constants.clip_large_link,
                                                 )

In [None]:
test_loader_preprocessed  = prepared_dataloaders(Constants.cub200_link,
                                                  concepts=concepts,
                                                  prep_loaders="test",
                                                  batch_size=128,
                                                  backbone_name=Constants.clip_large_link,
                                                 )

In [None]:
# One entry per network to train. Every per-network list passed to CBMConfig is derived from
# this table, so the lists cannot drift out of sync with each other or with num_nets.
NET_SPECS = [
    {"displayed_name": "CLIP L/14, gumbel, 3e-4", "training_method": "gumbel"},
    {"displayed_name": "CLIP L/14, contrastive, 3e-4", "training_method": "contrastive"},
]
LR = 3e-4  # used for both lrs and cbl_lrs of every network

config = CBMConfig(
    num_nets=len(NET_SPECS),
    num_concepts=len(concepts),
    num_classes=200,  # CUB-200 dataset (Constants.cub200_link)
    run_name="demo_run",
    net_types=["base" for _ in NET_SPECS],
    backbones=[Constants.clip_large_link for _ in NET_SPECS],
    displayed_names=[spec["displayed_name"] for spec in NET_SPECS],
    training_methods=[spec["training_method"] for spec in NET_SPECS],
    optimizers=["Adam" for _ in NET_SPECS],
    lrs=[LR for _ in NET_SPECS],
    cbl_lrs=[LR for _ in NET_SPECS],
    train_backbones=[False for _ in NET_SPECS],
    lora_connections=[],
)

In [None]:
NUM_EPOCHS = 10

# BottleneckTrainer takes the train, val and test loaders positionally, in that order.
split_loaders = (train_loader_preprocessed, val_loader_preprocessed, test_loader_preprocessed)
trainer = BottleneckTrainer(config, *split_loaders, num_epochs=NUM_EPOCHS, device=device)

In [None]:
# Train the networks described by `config` with the trainer built above.
# NOTE(review): the loop itself (epochs, validation, checkpointing) lives in
# BottleneckTrainer.train, which is not visible in this notebook — confirm details there.
trainer.train()

In [None]:
# Evaluate the trained networks, presumably on the test loader passed to BottleneckTrainer.
# NOTE(review): BottleneckTrainer.test is defined outside this notebook; confirm what it
# reports or returns.
trainer.test()