In [None]:
!pip install fiftyone
!pip install wandb

In [None]:
# Authenticate the wandb client; the training cells below log through runs
# created by `training.initWandbRun(...)` and closed with `run.finish()`.
# NOTE(review): presumably prompts interactively for an API key unless one is
# already configured (e.g. WANDB_API_KEY) — confirm before running headless.
!wandb login

In [None]:
import fiftyone as fo
from fiftyone.utils.huggingface import load_from_hub
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
from torch.utils.data import DataLoader

from src import datasets
from src import models
from src import training
from src import visualization

In [None]:
# Prefer a CUDA GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Load Data 

In [None]:
# Pull the subset from the Hugging Face Hub into a FiftyOne dataset.
# A small worker pool combined with a large batch size keeps the request rate
# down, in the hope of not being rate limited by the Hub.
dataset = load_from_hub(
    "MatthiasCr/multimodal-shapes-subset",
    name="multimodal-shapes-subset",
    batch_size=200,
    num_workers=2,
)

In [None]:
# Experiment configuration shared by every model trained below.
IMG_SIZE = 64    # size passed to transforms.Resize in the transform cell
BATCH_SIZE = 32
EPOCHS = 20
SEED = 42

# Seed PyTorch's global RNG (CPU and all CUDA devices). This makes weight initialisation
# and DataLoader shuffling reproducible across Restart & Run All. It does not force
# deterministic cuDNN kernels.
torch.manual_seed(SEED);

In [None]:
# Image preprocessing applied to both splits.
# NOTE: an int passed to Resize scales the *shorter* edge to IMG_SIZE and keeps the
# aspect ratio; use (IMG_SIZE, IMG_SIZE) instead if the source images are not square.
# ToTensor converts PIL images / uint8 arrays to float tensors scaled to [0, 1]. It is
# deprecated only in torchvision.transforms.v2 (replacement: ToImage() followed by
# ToDtype(torch.float32, scale=True)). The v1 `transforms` API used here still supports it.
img_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
])

train_dataset = datasets.MultimodalDataset(dataset, "train", img_transforms)
val_dataset = datasets.MultimodalDataset(dataset, "val", img_transforms)

# Training: shuffle each epoch, and drop a trailing partial batch so every optimizer step
# sees exactly BATCH_SIZE samples.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# Validation: keep every sample so the reported validation loss covers the whole split.
# drop_last=True would skip up to BATCH_SIZE - 1 samples, and would yield an empty loader
# whenever the split is smaller than BATCH_SIZE.
# NOTE(review): if training.train_model averages per-batch mean losses, the smaller final
# batch is weighted like a full one — confirm that is acceptable.
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

# Optimizer steps per epoch; used to size the cosine LR schedules.
steps_per_epoch = len(train_dataloader)

## Late Fusion

In [None]:
# Mean-reduced binary cross-entropy on raw logits (the sigmoid is fused into the loss),
# shared by every model trained below. Presumably the models emit unnormalised logits.
loss_func = nn.BCEWithLogitsLoss(reduction="mean")

In [None]:
# Late-fusion baseline: the model, its parameter count (passed to initWandbRun),
# an Adam optimizer, and a cosine LR schedule decaying from 1e-3 to 1e-6.
# T_max spans every optimizer step of the run, so the scheduler is presumably
# stepped once per batch inside training.train_model.
late_model = models.LateFusionModel().to(device)
late_num_parameters = sum(tensor.numel() for tensor in late_model.parameters())
late_optim = Adam(late_model.parameters(), lr=1e-3)
late_scheduler = CosineAnnealingLR(late_optim, T_max=steps_per_epoch * EPOCHS, eta_min=1e-6)

In [None]:
# Train the late-fusion model under its own W&B run, then plot its loss curves.
run = training.initWandbRun("late", EPOCHS, BATCH_SIZE, late_num_parameters)
try:
    late_train_loss, late_val_loss = training.train_model(
        late_model, late_optim, late_scheduler, loss_func, EPOCHS, train_dataloader, val_dataloader, device, run
    )
finally:
    # Close the run even if training raises or is interrupted, so a failed
    # run is not left open when the next cell starts a new one.
    run.finish()

visualization.plot_loss(EPOCHS,
    {
        "Late train Loss": late_train_loss,
        "Late Val Loss": late_val_loss
    }
)

## Intermediate Fusion

In [None]:
# Three intermediate-fusion variants that differ only in the `fusionType` passed to
# IntermediateFusionNet. The models are built in the original cat -> add -> hadamard
# order so any RNG-driven weight initialisation is unchanged.
cat_model = models.IntermediateFusionNet(fusionType="cat").to(device)
add_model = models.IntermediateFusionNet(fusionType="add").to(device)
had_model = models.IntermediateFusionNet(fusionType="hadamard").to(device)

# Parameter counts, passed to initWandbRun for each run.
cat_num_parameters = sum(tensor.numel() for tensor in cat_model.parameters())
add_num_parameters = sum(tensor.numel() for tensor in add_model.parameters())
had_num_parameters = sum(tensor.numel() for tensor in had_model.parameters())

# One Adam optimizer per model, all at lr = 1e-4.
cat_optim = Adam(cat_model.parameters(), lr=1e-4)
add_optim = Adam(add_model.parameters(), lr=1e-4)
had_optim = Adam(had_model.parameters(), lr=1e-4)

In [None]:
# train_model receives the LR scheduler as its 3rd positional argument (see the
# late-fusion cell). Omitting it shifted loss_func into the scheduler slot and every
# later argument one position left. Build a cosine schedule like the late-fusion one.
cat_scheduler = CosineAnnealingLR(cat_optim, T_max=EPOCHS * steps_per_epoch, eta_min=1e-6)

run = training.initWandbRun("intermediate (concatenation)", EPOCHS, BATCH_SIZE, cat_num_parameters)
try:
    cat_train_loss, cat_val_loss = training.train_model(
        cat_model, cat_optim, cat_scheduler, loss_func, EPOCHS, train_dataloader, val_dataloader, device, run
    )
finally:
    # Close the W&B run even if training raises or is interrupted.
    run.finish()

In [None]:
# train_model receives the LR scheduler as its 3rd positional argument (see the
# late-fusion cell). Omitting it shifted loss_func into the scheduler slot and every
# later argument one position left. Build a cosine schedule like the late-fusion one.
add_scheduler = CosineAnnealingLR(add_optim, T_max=EPOCHS * steps_per_epoch, eta_min=1e-6)

run = training.initWandbRun("intermediate (addition)", EPOCHS, BATCH_SIZE, add_num_parameters)
try:
    add_train_loss, add_val_loss = training.train_model(
        add_model, add_optim, add_scheduler, loss_func, EPOCHS, train_dataloader, val_dataloader, device, run
    )
finally:
    # Close the W&B run even if training raises or is interrupted.
    run.finish()

In [None]:
# train_model receives the LR scheduler as its 3rd positional argument (see the
# late-fusion cell). Omitting it shifted loss_func into the scheduler slot and every
# later argument one position left. Build a cosine schedule like the late-fusion one.
had_scheduler = CosineAnnealingLR(had_optim, T_max=EPOCHS * steps_per_epoch, eta_min=1e-6)

run = training.initWandbRun("intermediate (hadamard)", EPOCHS, BATCH_SIZE, had_num_parameters)
try:
    had_train_loss, had_val_loss = training.train_model(
        had_model, had_optim, had_scheduler, loss_func, EPOCHS, train_dataloader, val_dataloader, device, run
    )
finally:
    # Close the W&B run even if training raises or is interrupted.
    run.finish()

In [None]:
# Compare the validation-loss curves of all four fusion strategies in a single
# plot_loss call: the three intermediate-fusion variants plus the late-fusion baseline.
visualization.plot_loss(
    EPOCHS,
    {
        "Concat Valid Loss": cat_val_loss,
        "Addition Valid Loss": add_val_loss,
        "Hadamard Valid Loss": had_val_loss,
        "Late Valid Loss": late_val_loss,
    },
)