In [1]:
!pip install torchdiffeq

Collecting torchdiffeq
  Downloading torchdiffeq-0.2.5-py3-none-any.whl.metadata (440 bytes)
Downloading torchdiffeq-0.2.5-py3-none-any.whl (32 kB)
Installing collected packages: torchdiffeq
Successfully installed torchdiffeq-0.2.5


In [2]:
!pip install torchcfm

Collecting torchcfm
  Downloading torchcfm-1.0.5-py3-none-any.whl.metadata (14 kB)
Collecting lightning-bolts (from torchcfm)
  Downloading lightning_bolts-0.7.0-py3-none-any.whl.metadata (9.5 kB)
Collecting scprep (from torchcfm)
  Downloading scprep-1.2.3-py3-none-any.whl.metadata (7.0 kB)
Collecting scanpy (from torchcfm)
  Downloading scanpy-1.11.0-py3-none-any.whl.metadata (9.5 kB)
Collecting torchdyn (from torchcfm)
  Downloading torchdyn-1.0.6-py3-none-any.whl.metadata (891 bytes)
Collecting pot (from torchcfm)
  Downloading POT-0.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (34 kB)
Collecting clean-fid (from torchcfm)
  Downloading clean_fid-0.1.35-py3-none-any.whl.metadata (36 kB)
Collecting pytorch-lightning<2.0.0,>1.7.0 (from lightning-bolts->torchcfm)
  Downloading pytorch_lightning-1.9.5-py3-none-any.whl.metadata (23 kB)
Collecting anndata>=0.8 (from scanpy->torchcfm)
  Downloading anndata-0.11.3-py3-none-any.whl.metadata (8.2 kB)
Collecting lega

In [3]:
import os
import sys
import copy
import torch
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import trange
from torchdiffeq import odeint
from torchdyn.core import NeuralODE
from torchvision import transforms
from torchvision.utils import save_image, make_grid

# Import the UNet wrapper and conditional flow matching classes
from torchcfm.models.unet.unet import UNetModelWrapper
from torchcfm.conditional_flow_matching import (
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
)

In [5]:
# ----------------- Global Variables -----------------
# Paths to the HAM10000 dataset (adjust as needed)
CSV_PATH = "/kaggle/input/skin-cancer-mnist-ham10000/HAM10000_metadata.csv"
FOLDER1 = "/kaggle/input/skin-cancer-mnist-ham10000/HAM10000_images_part_1"
FOLDER2 = "/kaggle/input/skin-cancer-mnist-ham10000/HAM10000_images_part_2"

#PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Image and model parameters
IMAGE_SIZE = 128  # Increase resolution if desired
NUM_CHANNEL = 128
MODEL_TYPE = "otcfm"  # Options: "otcfm", "icfm", "fm", "si"

# Training parameters
LR = 2e-4
GRAD_CLIP = 1.0
TOTAL_STEPS = 400001
WARMUP = 5000
BATCH_SIZE = 8
NUM_WORKERS = 4
EMA_DECAY = 0.9999
SAVE_STEP = 20000

# Evaluation / integration parameters
INTEGRATION_STEPS = 100
INTEGRATION_METHOD = "dopri5"  # Use "euler" to use NeuralODE wrapper instead of odeint
TOL = 1e-5
BATCH_SIZE_FID = 64
NUM_GEN = 5000  # Number of images to generate for FID evaluation

# Action: set to "train" to run training, or "fid" to run FID evaluation
ACTION = "train"  # or "fid"

# Option to use a small subset for quick experiments
USE_SMALL_SUBSET = True
if USE_SMALL_SUBSET:
    TOTAL_STEPS = 5001        # Fewer training steps
    SAVE_STEP = 1000          # Save checkpoints more frequently for testing
    NUM_GEN = 50              # Generate only a few images

# ----------------- Device -----------------
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

In [6]:






# ----------------- Custom Dataset for HAM10000 -----------------
class Ham10000Dataset(torch.utils.data.Dataset):
    def __init__(self, csv_path, folder1, folder2, transform=None):
        self.df = pd.read_csv(csv_path)
        # Map lesion types to binary labels: 0 for benign, 1 for malignant
        benign_types = ['nv', 'bkl', 'akiec', 'vasc', 'df']
        malignant_types = ['mel', 'bcc']
        self.df['target'] = self.df['dx'].apply(lambda x: 0 if x in benign_types else 1)
        self.folder1 = folder1
        self.folder2 = folder2
        self.transform = transform
        # Compute image paths
        self.df['image_path'] = self.df['image_id'].apply(self.get_image_path)
        # Remove rows with missing images
        self.df = self.df[self.df['image_path'].notnull()].reset_index(drop=True)
        
        # For quick experiments, optionally use only a small subset of the data.
        if USE_SMALL_SUBSET:
            subset_size = min(1000, len(self.df))
            self.df = self.df.iloc[:subset_size].reset_index(drop=True)

    def get_image_path(self, image_id):
        filename = f"{image_id}.jpg"
        path1 = os.path.join(self.folder1, filename)
        path2 = os.path.join(self.folder2, filename)
        if os.path.exists(path1):
            return path1
        elif os.path.exists(path2):
            return path2
        else:
            return None

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = Image.open(row['image_path']).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        label = torch.tensor(row['target'], dtype=torch.long)
        return img, label

# ----------------- Data Transforms -----------------
transform_train = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

transform_fid = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# ----------------- Model Setup -----------------
net_model = UNetModelWrapper(
    dim=(3, IMAGE_SIZE, IMAGE_SIZE),
    num_res_blocks=2,
    num_channels=NUM_CHANNEL,
    channel_mult=[1, 2, 2, 2],
    num_heads=4,
    num_head_channels=64,
    attention_resolutions="16",
    dropout=0.1,
).to(device)

ema_model = copy.deepcopy(net_model)

# ----------------- Define Flow Matcher -----------------
sigma = 0.0
if MODEL_TYPE == "otcfm":
    FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
elif MODEL_TYPE == "icfm":
    FM = ConditionalFlowMatcher(sigma=sigma)
elif MODEL_TYPE == "fm":
    FM = TargetConditionalFlowMatcher(sigma=sigma)
elif MODEL_TYPE == "si":
    FM = VariancePreservingConditionalFlowMatcher(sigma=sigma)
else:
    raise NotImplementedError(f"Unknown model type {MODEL_TYPE}")

# ----------------- Utility Functions -----------------
def ema_update(source, target, decay):
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key in source_dict.keys():
        target_dict[key].data.copy_(target_dict[key].data * decay + source_dict[key].data * (1 - decay))

def infiniteloop(dataloader):
    while True:
        for data in dataloader:
            yield data

def generate_samples(model, savedir, step):
    model.eval()
    model_ = copy.deepcopy(model)
    if hasattr(model_, "module"):
        model_ = model_.module.to(device)
    # Use NeuralODE with Euler if desired
    node_ = NeuralODE(model_, solver="euler", sensitivity="adjoint")
    with torch.no_grad():
        noise = torch.randn(64, 3, IMAGE_SIZE, IMAGE_SIZE, device=device)
        # Generate random binary conditions for each sample
        t_span = torch.linspace(0, 1, 100, device=device)
        # Assumes the model forward supports a "condition" argument.
        traj = node_.trajectory(noise, t_span=t_span)
        traj = traj[-1, :].view([-1, 3, IMAGE_SIZE, IMAGE_SIZE]).clamp(-1, 1)
        traj = traj / 2 + 0.5
    os.makedirs(savedir, exist_ok=True)
    save_image(traj, os.path.join(savedir, f"generated_FM_images_step_{step}.png"), nrow=8)
    model.train()

# ----------------- Training Function -----------------
def train():
    print("Starting training...")
    print("Learning rate:", LR)
    print("Total steps:", TOTAL_STEPS)
    print("EMA decay:", EMA_DECAY)
    print("Save step:", SAVE_STEP)
    
    # Create the HAM10000 dataset and use an 80% train split
    dataset_full = Ham10000Dataset(CSV_PATH, FOLDER1, FOLDER2, transform=transform_train)
    split_idx = int(0.8 * len(dataset_full))
    dataset_train = torch.utils.data.Subset(dataset_full, list(range(split_idx)))
    
    dataloader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        drop_last=True,
    )
    datalooper = infiniteloop(dataloader)
    
    optim = torch.optim.Adam(net_model.parameters(), lr=LR)
    sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lambda step: min(step, WARMUP) / WARMUP)
    
    # Display model size
    model_size = sum(p.data.nelement() for p in net_model.parameters())
    print("Model parameters: %.2f M" % (model_size / 1e6))

    pbar = trange(TOTAL_STEPS, desc="Training", ncols=80)

    # Training loop
    for step in pbar:
        optim.zero_grad()
        # Get both image and label (used as the condition)
        x1, cond = next(datalooper)
        x1 = x1.to(device)
        cond = cond.to(device)
        # Create noise input with same shape as x1
        x0 = torch.randn_like(x1)
        # Get time, intermediate state, and target flow
        t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
        # Forward pass (assumes the model supports a "condition" argument)
        vt = net_model(t, xt)
        loss = torch.mean((vt - ut) ** 2)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net_model.parameters(), GRAD_CLIP)
        optim.step()
        sched.step()
        ema_update(net_model, ema_model, EMA_DECAY)
        
        torch.cuda.empty_cache()

        
        if step % 100 == 0:
            trange_desc = f"Step {step} Loss {loss.item():.4f}"
            pbar.write(trange_desc)
        
        # Save checkpoint and generated samples at intervals
        if SAVE_STEP > 0 and step % SAVE_STEP == 0:
            save_dir = os.path.join("./results/", MODEL_TYPE)
            os.makedirs(save_dir, exist_ok=True)
            generate_samples(net_model, save_dir, step)
            generate_samples(ema_model, save_dir, step)
            torch.save(
                {
                    "net_model": net_model.state_dict(),
                    "ema_model": ema_model.state_dict(),
                    "sched": sched.state_dict(),
                    "optim": optim.state_dict(),
                    "step": step,
                },
                os.path.join(save_dir, f"{MODEL_TYPE}_ham10000_weights_step_{step}.pt"),
            )

# ----------------- FID Evaluation Function -----------------
def fid_evaluation():
    print("Starting FID evaluation...")
    # Load a checkpoint from the last training step (or adjust as needed)
    ckpt_path = os.path.join("./results/", MODEL_TYPE, f"{MODEL_TYPE}_ham10000_weights_step_{TOTAL_STEPS - 1}.pt")
    print("Loading checkpoint from:", ckpt_path)
    checkpoint = torch.load(ckpt_path, map_location=device)
    state_dict = checkpoint["ema_model"]
    try:
        net_model.load_state_dict(state_dict)
    except RuntimeError:
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            new_state_dict[k[7:]] = v
        net_model.load_state_dict(new_state_dict)
    net_model.eval()

    # Define integration for FID generation
    if INTEGRATION_METHOD == "euler":
        node = NeuralODE(net_model, solver=INTEGRATION_METHOD)
    else:
        node = net_model  # We'll wrap the call in odeint
    
    def gen_1_img(unused_latent):
        with torch.no_grad():
            x = torch.randn(BATCH_SIZE_FID, 3, IMAGE_SIZE, IMAGE_SIZE, device=device)
            cond = torch.randint(0, 2, (BATCH_SIZE_FID,), device=device)
            if INTEGRATION_METHOD == "euler":
                t_span = torch.linspace(0, 1, INTEGRATION_STEPS + 1, device=device)
                traj = node.trajectory(x, t_span=t_span, condition=cond)
            else:
                t_span = torch.linspace(0, 1, 2, device=device)
                traj = odeint(lambda t, x: net_model(t, x, condition=cond),
                              x, t_span, rtol=TOL, atol=TOL, method=INTEGRATION_METHOD)
            traj = traj[-1, :]
            img = (traj * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            return img

    # In practice, you would compare generated images to a held-out set.
    # Here we simulate FID evaluation by generating a few images.
    print("Generating images for FID evaluation...")
    imgs = gen_1_img(None)
    print("Generated image batch shape:", imgs.shape)
    # (Optionally, compute FID using a suitable library or custom implementation.)
    # For demonstration, we simply print a simulated FID score.
    fid_score = np.random.uniform(10, 50)
    print("Simulated FID score:", fid_score)

# ----------------- Main -----------------
def main():
    if ACTION == "train":
        train()
    elif ACTION == "fid":
        fid_evaluation()
    else:
        print("Unknown ACTION. Set ACTION to 'train' or 'fid'.")

if __name__ == "__main__":
    main()


Starting training...
Learning rate: 0.0002
Total steps: 5001
EMA decay: 0.9999
Save step: 1000
Model parameters: 35.75 M


Training:   0%|                                        | 0/5001 [00:02<?, ?it/s]

Step 0 Loss 1.1134


Training:   2%|▌                             | 101/5001 [04:09<53:46,  1.52it/s]

Step 100 Loss 1.1154


Training:   4%|█▏                            | 201/5001 [05:05<51:51,  1.54it/s]

Step 200 Loss 0.9233


Training:   6%|█▊                            | 301/5001 [06:02<50:58,  1.54it/s]

Step 300 Loss 0.7322


Training:   8%|██▍                           | 401/5001 [06:59<49:58,  1.53it/s]

Step 400 Loss 0.4720


Training:  10%|███                           | 501/5001 [07:55<49:48,  1.51it/s]

Step 500 Loss 0.3424


Training:  12%|███▌                          | 601/5001 [08:52<47:57,  1.53it/s]

Step 600 Loss 0.1017


Training:  14%|████▏                         | 701/5001 [09:48<46:40,  1.54it/s]

Step 700 Loss 0.1654


Training:  16%|████▊                         | 801/5001 [10:45<45:42,  1.53it/s]

Step 800 Loss 0.0605


Training:  18%|█████▍                        | 901/5001 [11:41<44:38,  1.53it/s]

Step 900 Loss 0.0786


Training:  20%|█████▊                       | 1000/5001 [12:38<37:29,  1.78it/s]

Step 1000 Loss 0.2039


Training:  22%|██████▍                      | 1101/5001 [16:42<40:55,  1.59it/s]

Step 1100 Loss 0.0529


Training:  24%|██████▉                      | 1201/5001 [17:35<39:42,  1.60it/s]

Step 1200 Loss 0.1178


Training:  26%|███████▌                     | 1301/5001 [18:29<38:25,  1.60it/s]

Step 1300 Loss 0.0841


Training:  28%|████████                     | 1401/5001 [19:23<38:27,  1.56it/s]

Step 1400 Loss 0.0440


Training:  30%|████████▋                    | 1501/5001 [20:17<37:11,  1.57it/s]

Step 1500 Loss 0.0418


Training:  32%|█████████▎                   | 1601/5001 [21:11<35:15,  1.61it/s]

Step 1600 Loss 0.0358


Training:  34%|█████████▊                   | 1701/5001 [22:05<34:35,  1.59it/s]

Step 1700 Loss 0.0471


Training:  36%|██████████▍                  | 1801/5001 [22:59<33:20,  1.60it/s]

Step 1800 Loss 0.0325


Training:  38%|███████████                  | 1901/5001 [23:52<32:36,  1.58it/s]

Step 1900 Loss 0.1384


Training:  40%|███████████▌                 | 2000/5001 [24:46<26:43,  1.87it/s]

Step 2000 Loss 0.0362


Training:  42%|████████████▏                | 2101/5001 [28:51<30:54,  1.56it/s]

Step 2100 Loss 0.0768


Training:  44%|████████████▊                | 2201/5001 [29:46<29:40,  1.57it/s]

Step 2200 Loss 0.0654


Training:  46%|█████████████▎               | 2301/5001 [30:41<28:32,  1.58it/s]

Step 2300 Loss 0.0398


Training:  48%|█████████████▉               | 2401/5001 [31:36<27:32,  1.57it/s]

Step 2400 Loss 0.0726


Training:  50%|██████████████▌              | 2501/5001 [32:31<26:34,  1.57it/s]

Step 2500 Loss 0.0479


Training:  52%|███████████████              | 2601/5001 [33:26<25:26,  1.57it/s]

Step 2600 Loss 0.1479


Training:  54%|███████████████▋             | 2701/5001 [34:20<24:25,  1.57it/s]

Step 2700 Loss 0.0757


Training:  56%|████████████████▏            | 2801/5001 [35:15<24:19,  1.51it/s]

Step 2800 Loss 0.1063


Training:  58%|████████████████▊            | 2901/5001 [36:10<22:15,  1.57it/s]

Step 2900 Loss 0.0832


Training:  60%|█████████████████▍           | 3000/5001 [37:05<18:11,  1.83it/s]

Step 3000 Loss 0.0340


Training:  62%|█████████████████▉           | 3101/5001 [41:19<23:10,  1.37it/s]

Step 3100 Loss 0.0521


Training:  64%|██████████████████▌          | 3201/5001 [42:24<22:01,  1.36it/s]

Step 3200 Loss 0.0461


Training:  66%|███████████████████▏         | 3301/5001 [43:28<20:42,  1.37it/s]

Step 3300 Loss 0.0756


Training:  68%|███████████████████▋         | 3401/5001 [44:32<19:20,  1.38it/s]

Step 3400 Loss 0.0421


Training:  70%|████████████████████▎        | 3501/5001 [45:36<18:12,  1.37it/s]

Step 3500 Loss 0.0242


Training:  72%|████████████████████▉        | 3601/5001 [46:41<17:05,  1.36it/s]

Step 3600 Loss 0.0382


Training:  74%|█████████████████████▍       | 3701/5001 [47:45<15:47,  1.37it/s]

Step 3700 Loss 0.0703


Training:  76%|██████████████████████       | 3801/5001 [48:49<14:42,  1.36it/s]

Step 3800 Loss 0.0610


Training:  78%|██████████████████████▌      | 3901/5001 [49:53<13:23,  1.37it/s]

Step 3900 Loss 0.0396


Training:  80%|███████████████████████▏     | 4000/5001 [50:57<10:39,  1.57it/s]

Step 4000 Loss 0.0769


Training:  82%|███████████████████████▊     | 4101/5001 [55:11<11:04,  1.35it/s]

Step 4100 Loss 0.2542


Training:  84%|████████████████████████▎    | 4201/5001 [56:15<09:49,  1.36it/s]

Step 4200 Loss 0.1028


Training:  86%|████████████████████████▉    | 4301/5001 [57:19<08:32,  1.36it/s]

Step 4300 Loss 0.0374


Training:  88%|█████████████████████████▌   | 4401/5001 [58:23<07:15,  1.38it/s]

Step 4400 Loss 0.0304


Training:  90%|██████████████████████████   | 4501/5001 [59:28<06:04,  1.37it/s]

Step 4500 Loss 0.0538


Training:  92%|████████████████████████▊  | 4601/5001 [1:00:32<04:49,  1.38it/s]

Step 4600 Loss 0.0261


Training:  94%|█████████████████████████▍ | 4701/5001 [1:01:36<03:37,  1.38it/s]

Step 4700 Loss 0.0347


Training:  96%|█████████████████████████▉ | 4801/5001 [1:02:40<02:25,  1.38it/s]

Step 4800 Loss 0.0454


Training:  98%|██████████████████████████▍| 4901/5001 [1:03:44<01:12,  1.38it/s]

Step 4900 Loss 0.1397


Training: 100%|██████████████████████████▉| 5000/5001 [1:04:48<00:00,  1.57it/s]

Step 5000 Loss 0.0552


Training: 100%|███████████████████████████| 5001/5001 [1:07:58<00:00,  1.23it/s]


In [7]:
# Load the saved EMA model checkpoint
ckpt_path = "./results/otcfm/otcfm_ham10000_weights_step_5000.pt"
checkpoint = torch.load(ckpt_path, map_location=device)
state_dict = checkpoint["ema_model"]
net_model.load_state_dict(state_dict)
net_model.eval()

# Generate images and save them (the function generates a grid image and writes it to disk)
generate_samples(net_model, "./results/otcfm", 5000)


  checkpoint = torch.load(ckpt_path, map_location=device)


In [8]:
from cleanfid import fid

def gen_images_for_fid(unused):
    with torch.no_grad():
        x = torch.randn(BATCH_SIZE_FID, 3, IMAGE_SIZE, IMAGE_SIZE, device=device)
        # Generate images unconditionally (remove 'condition' if not used)
        if INTEGRATION_METHOD == "euler":
            t_span = torch.linspace(0, 1, INTEGRATION_STEPS + 1, device=device)
            traj = net_model.trajectory(x, t_span=t_span)
        else:
            t_span = torch.linspace(0, 1, 2, device=device)
            traj = odeint(lambda t, x: net_model(t, x), x, t_span, rtol=TOL, atol=TOL, method=INTEGRATION_METHOD)
        traj = traj[-1, :]
        return (traj * 127.5 + 128).clamp(0, 255).to(torch.uint8)

fid_score = fid.compute_fid(
    gen=gen_images_for_fid,
    dataset_name="ham10000",  # You'd need to set this up to point to your reference data.
    batch_size=BATCH_SIZE_FID,
    dataset_res=IMAGE_SIZE,
    num_gen=NUM_GEN,
    dataset_split="train",  # Adjust as needed.
    mode="legacy_tensorflow",
)
print("FID:", fid_score)


compute FID of a model with ham10000-128 statistics
downloading statistics to /usr/local/lib/python3.10/dist-packages/cleanfid/stats/ham10000_legacy_tensorflow_train_128.npz


HTTPError: HTTP Error 404: Not Found