# function to make text + image dataset

In [3]:
import json
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset

class PDETensorTextDataset(Dataset):
    def __init__(self, jsonl_path, dtype=torch.float32):
        self.jsonl_path = Path(jsonl_path)
        self.root = self.jsonl_path.parent
        self.dtype = dtype

        with open(self.jsonl_path, "r") as f:
            self.samples = [json.loads(l) for l in f]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        tensor_path = self.root / sample["tensor"]
        tensor = np.load(tensor_path)

        tensor = torch.from_numpy(tensor).to(self.dtype)
        text = sample["text"]

        return tensor, text

In [None]:
dataset = PDETensorTextDataset("/Users/divyam/Course/Project Arbeit/pde_solver/src/dataset/annotations.jsonl")

sol, txt = dataset[0]
print(sol.shape, sol.dtype)
print(txt)


torch.Size([128, 384]) torch.float32
This model captures interactions between kinetic energy, pressure, and radiative cooling.


# Vision Only CLIP

In [5]:
from torch.utils.data import Dataset
class VisionOnlyPDEDataset(Dataset):
    def __init__(self, base_dataset, input_steps=5, output_steps=10):
        self.base = base_dataset
        self.input_steps = input_steps
        self.output_steps = output_steps

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        solution, _ = self.base[idx]   # [T, X]

        x = solution[:self.input_steps]                         # [5, X]
        y = solution[self.input_steps:self.input_steps+10]      # [10, X]

        # Add spatial height dimension for Conv2d
        x = x.unsqueeze(1)   # [5, 1, X]
        y = y.unsqueeze(1)   # [10, 1, X]

        return x.float(), y.float()


In [6]:
import torch.nn as nn
import torch

class CNNPDEBaseline(nn.Module):
    def __init__(self, input_steps, output_steps):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(input_steps, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
        )

        self.decoder = nn.Conv2d(
            128, output_steps, kernel_size=3, padding=1
        )

    def forward(self, x):
        # x: [B, input_steps, 1, X]
        z = self.encoder(x)
        out = self.decoder(z)   # [B, output_steps, 1, X]
        return out



In [7]:
def train_pde_baseline(model, dataloader, optimizer, device, epochs=20):
    model.train()
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        total_loss = 0.0

        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)

            preds = model(x)
            loss = criterion(preds, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1} | Loss {total_loss/len(dataloader):.6f}")


In [None]:
from torch.utils.data import DataLoader

base_dataset = PDETensorTextDataset(
    "/Users/divyam/Course/Project Arbeit/pde_solver/src/dataset/annotations.jsonl"
)

vision_dataset = VisionOnlyPDEDataset(base_dataset, input_steps=5, output_steps=10)

vision_loader = DataLoader(vision_dataset, batch_size=16, shuffle=True)

device = "mps"

vision_model = CNNPDEBaseline(input_steps=5, output_steps=10).to(device)

optimizer = torch.optim.AdamW(vision_model.parameters(), lr=3e-4)

train_pde_baseline(
    vision_model,
    vision_loader,
    optimizer,
    device=device,
    epochs=20
)

Epoch 1 | Loss 1179.096507
Epoch 2 | Loss 132.490448
Epoch 3 | Loss 111.409473
Epoch 4 | Loss 101.605314
Epoch 5 | Loss 96.595752
Epoch 6 | Loss 92.406907
Epoch 7 | Loss 90.204022
Epoch 8 | Loss 88.195630
Epoch 9 | Loss 87.203460
Epoch 10 | Loss 86.246500
Epoch 11 | Loss 85.383471
Epoch 12 | Loss 84.735451
Epoch 13 | Loss 83.920526
Epoch 14 | Loss 83.631067
Epoch 15 | Loss 83.377690
Epoch 16 | Loss 82.467060
Epoch 17 | Loss 82.252487
Epoch 18 | Loss 81.550521
Epoch 19 | Loss 81.551677
Epoch 20 | Loss 81.077893



# Vision Text CLIP

In [9]:
class PDEEncoder(nn.Module):
    def __init__(self, input_steps, embed_dim):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv1d(input_steps, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)
        )

        self.proj = nn.Linear(128, embed_dim)

    def forward(self, x):
        # x: [B, input_steps, X]
        z = self.encoder(x).squeeze(-1)  # [B, 128]
        return F.normalize(self.proj(z), dim=-1)


In [10]:
import torch.nn.functional as F

class PDETextCLIP(nn.Module):
    def __init__(self, input_steps, text_dim, embed_dim=512):
        super().__init__()

        self.vision_encoder = PDEEncoder(input_steps, embed_dim)
        self.text_proj = nn.Linear(text_dim, embed_dim)

    def forward(self, pde_init, text_emb):
        img_emb = self.vision_encoder(pde_init)
        txt_emb = F.normalize(self.text_proj(text_emb), dim=-1)
        return img_emb, txt_emb
    

def clip_loss(img_emb, txt_emb, temperature=0.07):
    logits = img_emb @ txt_emb.T / temperature
    labels = torch.arange(len(img_emb)).to(img_emb.device)

    loss_i = nn.CrossEntropyLoss()(logits, labels)
    loss_t = nn.CrossEntropyLoss()(logits.T, labels)
    return (loss_i + loss_t) / 2


In [11]:
class PDECLIPDataset(Dataset):
    def __init__(self, base_dataset, input_steps=5, output_steps=10):
        self.base = base_dataset
        self.input_steps = input_steps
        self.output_steps = output_steps

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        solution, text = self.base[idx]

        if isinstance(solution, (tuple, list)):
            solution = solution[0]

        x = solution[:self.input_steps]   # [5, X]
        y = solution[self.input_steps:self.input_steps+10]

        return x.float(), y.float(), text 

In [12]:
def train_clip(
    model,
    dataloader,
    text_encoder,
    optimizer,
    device,
    epochs=20
):
    model.train()

    for epoch in range(epochs):
        total_loss = 0.0

        for x, y, texts in dataloader:
            x = x.to(device)  # [B, input_steps, X]
            y = y.to(device)  # [B, output_steps, X]

            with torch.no_grad():
                text_emb = torch.from_numpy(
                    text_encoder.encode(texts)
                ).to(device)

            img_emb, txt_emb = model(x, text_emb)
            loss = clip_loss(img_emb, txt_emb)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1} | CLIP Loss {total_loss/len(dataloader):.4f}")


In [None]:
from sentence_transformers import SentenceTransformer
from torch.utils.data import DataLoader

device = "mps"

text_encoder = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    device=device
)

base_dataset = PDETensorTextDataset(
    "/Users/divyam/Course/Project Arbeit/pde_solver/src/dataset/annotations.jsonl"
)

clip_dataset = PDECLIPDataset(
    base_dataset,
    input_steps=5,
    output_steps=10
)

loader = DataLoader(
    clip_dataset,
    batch_size=32,
    shuffle=True
)

vision_text_solver_model = PDETextCLIP(
    input_steps=5,
    text_dim=384,
    embed_dim=512
).to(device)

optimizer = torch.optim.AdamW(
    vision_text_solver_model.parameters(),
    lr=3e-4
)

train_clip(
    vision_text_solver_model,
    loader,
    text_encoder,
    optimizer,
    device=device,
    epochs=20
)


Epoch 1 | CLIP Loss 3.4435
Epoch 2 | CLIP Loss 3.4337
Epoch 3 | CLIP Loss 3.4300
Epoch 4 | CLIP Loss 3.4258
Epoch 5 | CLIP Loss 3.4291
Epoch 6 | CLIP Loss 3.4254
Epoch 7 | CLIP Loss 3.4132
Epoch 8 | CLIP Loss 3.4076
Epoch 9 | CLIP Loss 3.3994
Epoch 10 | CLIP Loss 3.3952
Epoch 11 | CLIP Loss 3.3869
Epoch 12 | CLIP Loss 3.3733
Epoch 13 | CLIP Loss 3.3514
Epoch 14 | CLIP Loss 3.3483
Epoch 15 | CLIP Loss 3.3316
Epoch 16 | CLIP Loss 3.3188
Epoch 17 | CLIP Loss 3.2947
Epoch 18 | CLIP Loss 3.2550
Epoch 19 | CLIP Loss 3.2281
Epoch 20 | CLIP Loss 3.2057


# Vision + Shuffled text

In [14]:
import random
from torch.utils.data import Dataset

class ShuffledTextPDECLIPDataset(Dataset):
    def __init__(self, base_dataset, input_steps=5, output_steps=10):
        self.base = base_dataset
        self.input_steps = input_steps
        self.output_steps = output_steps
        self.texts = [s["text"] for s in base_dataset.samples]

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        solution, _ = self.base[idx]

        # correct PDE initial condition
        x = solution[:self.input_steps].float()  # [5, X]
        y = solution[self.input_steps:self.input_steps+10].float()  # [10, X]

        # shuffled (incorrect) text
        random_text = random.choice(self.texts)

        return x, y, random_text

In [15]:
from torch.utils.data import DataLoader

shuffled_dataset = ShuffledTextPDECLIPDataset(
    base_dataset,
    input_steps=5,
    output_steps=10
)

loader = DataLoader(
    shuffled_dataset,
    batch_size=32,
    shuffle=True
)

train_clip(
    vision_text_solver_model,
    loader,
    text_encoder,
    optimizer,
    device="mps",
    epochs=20
)


Epoch 1 | CLIP Loss 3.4977
Epoch 2 | CLIP Loss 3.4385
Epoch 3 | CLIP Loss 3.4359
Epoch 4 | CLIP Loss 3.4338
Epoch 5 | CLIP Loss 3.4371
Epoch 6 | CLIP Loss 3.4360
Epoch 7 | CLIP Loss 3.4326
Epoch 8 | CLIP Loss 3.4358
Epoch 9 | CLIP Loss 3.4363
Epoch 10 | CLIP Loss 3.4354
Epoch 11 | CLIP Loss 3.4354
Epoch 12 | CLIP Loss 3.4350
Epoch 13 | CLIP Loss 3.4346
Epoch 14 | CLIP Loss 3.4355
Epoch 15 | CLIP Loss 3.4343
Epoch 16 | CLIP Loss 3.4355
Epoch 17 | CLIP Loss 3.4355
Epoch 18 | CLIP Loss 3.4352
Epoch 19 | CLIP Loss 3.4336
Epoch 20 | CLIP Loss 3.4353


# Testing

In [None]:
from torch.utils.data import DataLoader

test_base_dataset = PDETensorTextDataset(
    "/Users/divyam/Course/Project Arbeit/pde_solver/src/dataset/annotations_test.jsonl"
)


In [17]:
test_vision_dataset = VisionOnlyPDEDataset(
    test_base_dataset,
    input_steps=5,
    output_steps=10
)

test_vision_loader = DataLoader(
    test_vision_dataset,
    batch_size=16,
    shuffle=False   
)


In [18]:
test_clip_dataset = PDECLIPDataset(
    test_base_dataset,
    input_steps=5
)

test_clip_loader = DataLoader(
    test_clip_dataset,
    batch_size=32,
    shuffle=False   # CRITICAL
)


In [19]:
import random
random.seed(42)

test_shuffled_clip_dataset = ShuffledTextPDECLIPDataset(
    test_base_dataset,
    input_steps=5
)

test_shuffled_clip_loader = DataLoader(
    test_shuffled_clip_dataset,
    batch_size=32,
    shuffle=False   
)


In [20]:
class VisionTextPDESolver(nn.Module):
    def __init__(self, vision_solver, text_dim, input_steps):
        super().__init__()
        self.vision_solver = vision_solver
        self.text_proj = nn.Linear(text_dim, input_steps)

    def forward(self, x, text_emb):
        """
        x: [B, input_steps, X]  OR  [B, input_steps, 1, X]
        text_emb: [B, text_dim]
        """

        if x.dim() == 3:
            x = x.unsqueeze(2)   # [B, input_steps, 1, X]

        # text conditioning
        cond = self.text_proj(text_emb)          # [B, input_steps]
        cond = cond.unsqueeze(-1).unsqueeze(-1)  # [B, input_steps, 1, 1]

        x_cond = x + cond                        # broadcast over X

        preds = self.vision_solver(x_cond)       # [B, output_steps, 1, X]

        return preds


In [21]:
import torch.nn.functional as F

@torch.no_grad()
def evaluate_pde_solver(model, dataloader, device):
    model.eval()
    mse = 0.0
    n = 0

    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)

        preds = model(x)
        mse += F.mse_loss(preds, y, reduction="sum").item()
        n += y.numel()

    return mse / n


In [22]:
@torch.no_grad()
def evaluate_vision_text_solver(
    model,
    dataloader,
    text_encoder,
    device
):
    model.eval()
    mse = 0.0
    n = 0

    for x, y, texts in dataloader:
        x = x.to(device)
        y = y.to(device)

        if y.dim() == 3:
            y = y.unsqueeze(2) 

        text_emb = torch.from_numpy(
            text_encoder.encode(list(texts))
        ).to(device)

        preds = model(x, text_emb)

        mse += F.mse_loss(preds, y, reduction="sum").item()
        n += y.numel()

    return mse / n


In [23]:
test_mse = evaluate_pde_solver(
    vision_model,
    test_vision_loader,
    device
)
print("Vision-only OOD MSE:", test_mse)


Vision-only OOD MSE: 79.23567497049184


In [24]:
test_vt_dataset = PDECLIPDataset(
    test_base_dataset,
    input_steps=5,
    output_steps=10
)

test_vt_loader = DataLoader(
    test_vt_dataset,
    batch_size=16,
    shuffle=False
)


In [25]:
vision_text_solver_model = VisionTextPDESolver(
    vision_solver=vision_model,   
    text_dim=384,
    input_steps=5
).to(device)


In [26]:
test_mse = evaluate_vision_text_solver(
    vision_text_solver_model,
    test_vt_loader,
    text_encoder,
    device
)

print("Vision + text OOD MSE:", test_mse)

Vision + text OOD MSE: 79.38802373401403


In [27]:
from torch.utils.data import DataLoader

test_shuffled_dataset = ShuffledTextPDECLIPDataset(
    test_base_dataset,      # annotations_test.jsonl
    input_steps=5,
    output_steps=10
)

test_shuffled_loader = DataLoader(
    test_shuffled_dataset,
    batch_size=16,
    shuffle=False
)


In [28]:
shuffled_mse = evaluate_vision_text_solver(
    vision_text_solver_model,
    test_shuffled_loader,
    text_encoder,
    device
)

print("Vision + shuffled text OOD MSE:", shuffled_mse)


Vision + shuffled text OOD MSE: 79.3934084000848
