In [1]:
import json
import os

import pandas as pd
import torch
import wandb
from torch.utils.data import DataLoader
from torchvision import transforms

from handsoncv.datasets import CILPFusionDataset
from handsoncv.models import LateFusionNet, IntermediateFusionNet
from handsoncv.training import train_fusion_model

In [None]:
# Load the subset mapping created with notebook 01_dataset_exploration.ipynb.
with open("subset_mapping.json", "r", encoding="utf-8") as f:
    subset_records = json.load(f)

# Split based on the tags saved in each record.
train_samples = [s for s in subset_records if "train" in s["tags"]]
val_samples = [s for s in subset_records if "val" in s["tags"]]
assert train_samples, "no records tagged 'train' in subset_mapping.json"
assert val_samples, "no records tagged 'val' in subset_mapping.json"

# Root directory the dataset uses to find the azimuth/zenith data.
# Set the CILP_DATA_ROOT environment variable to point elsewhere instead of
# editing this machine-specific default.
ROOT_DATA = os.environ.get(
    "CILP_DATA_ROOT",
    "~/Documents/repos/BuildingAIAgentsWithMultimodalModels/data/assessment/",
)
DATA_DIR = os.path.expanduser(ROOT_DATA)
IMG_SIZE = 64
BATCH_SIZE = 32

img_transforms = transforms.Compose([
    # An (h, w) tuple forces an exact IMG_SIZE x IMG_SIZE output; a bare int
    # would only rescale the shorter edge and keep the aspect ratio.
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),  # Converts to a tensor scaled into [0, 1]
])

# Both splits must use the transform defined above (previously referenced an
# undefined name `transform`).
train_ds = CILPFusionDataset(train_samples, root_dir=DATA_DIR, transform=img_transforms)
val_ds = CILPFusionDataset(val_samples, root_dir=DATA_DIR, transform=img_transforms)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

print(f"Ready to train with {len(train_ds)} training pairs.")

In [None]:
# Train each fusion strategy on the train/val loaders built in the previous
# cell, log every strategy to its own W&B run, and collect the metrics.
SEED = 42
LEARNING_RATE = 1e-4
# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed before building the models so weight init and loader shuffling are
# repeatable across Restart & Run All.
torch.manual_seed(SEED)

# Experiment suite: (display name, model instance).
strategies = [
    ("Late Fusion", LateFusionNet()),
    ("Int Fusion Concat", IntermediateFusionNet(mode="concat")),
    ("Int Fusion Add", IntermediateFusionNet(mode="add")),
    ("Int Fusion Mul", IntermediateFusionNet(mode="mul")),
]

results = []

for name, model in strategies:
    print(f"Training {name}...")
    # Move the model before constructing its optimizer, as torch.optim recommends.
    model.to(DEVICE)
    wandb.init(project="handsoncv-fusion", name=name)
    try:
        metrics = train_fusion_model(
            model,
            train_loader,
            val_loader,
            optimizer=torch.optim.Adam(model.parameters(), lr=LEARNING_RATE),
            criterion=torch.nn.CrossEntropyLoss(),
            device=DEVICE,
        )
    finally:
        # Close the run even if training fails, so the next strategy does not
        # log into a stale run.
        wandb.finish()

    # Put the architecture name first so it is the leading table column.
    results.append({"Architecture": name, **metrics})

# Final comparison table: one row per strategy (rich display as last expression).
df = pd.DataFrame(results)
df