In [1]:
import sys
from __future__ import annotations

In [2]:
from pathlib import Path

In [3]:
data_dir = "../benchmark/"

In [4]:
from src.model_clay import CLAYModule
import src.datamodule
from src.datamodule_eval_local import ClayDataset, ClayDataModule
import pandas as pd
import random
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import numpy as np
import einops
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
import rasterio as rio
from einops import rearrange, reduce
import torch

In [5]:
model = CLAYModule.load_from_checkpoint("../mae_epoch-10_val-loss-0.563.ckpt", mask_ratio=0.)
model.eval();

In [6]:
dm = ClayDataModule(data_dir=data_dir, batch_size=4)

In [9]:
dm.setup()

Total number of chips: 19


In [10]:
val_dl = iter(dm.val_dataloader())

In [11]:
batch = next(val_dl)

In [12]:
batch

{'labels': tensor([[[[1., 0., 0.,  ..., 0., 0., 0.],
           [1., 0., 0.,  ..., 0., 0., 0.],
           [1., 1., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]]]),
 'pixels': tensor([[[[-0.4157, -0.4420, -0.4249,  ..., -0.3267, -0.3222, -0.3279],
           [-0.3632, -0.4055, -0.4180,  ..., -0.3290, -0.3096, -0.3267],
           [-0.3804, -0.3621, -0.4009,  ..., -0.2994, -0.2560, -0.2720],
           ...,
           [-0.2058, -0.2891, -0.0358,  ..., -0.4545, -0.4363, -0.4271],
           [ 0.0030, -0.2857, -0.1864,  ..., -0.4420, -0.4465, -0.4340],
           [ 0.3932, -0.2663, -0.2378,  ..., -0.4488, -0.4557, -0.4465]],
 
          [[-0.3860, -0.4002, -0.3980,  ..., -0.2684, -0.2825, -0.2880],
           [-0.3697, -0.3947, -0.3893,  ..., -0.2597, -0.3010, -0.2956],
           [-0.3413, -0.3435, -0.3446,  ..., -0.2052, -0.2858, -0.2684],
           ...,
          

In [13]:
batch["pixels"] = batch["pixels"].to(model.device)
batch["timestep"] = batch["timestep"].to(model.device)
batch["latlon"] = batch["latlon"].to(model.device)
emb = model.model.encoder(batch)

In [14]:
emb

(tensor([[[-0.1726,  0.0199,  0.0531,  ...,  0.0189,  0.0039, -0.2733],
          [-0.5136, -0.2629, -0.0092,  ...,  0.0840, -0.0267,  0.0048],
          [-0.1802, -0.0342,  0.0473,  ...,  0.0048, -0.0149, -0.1933],
          ...,
          [-0.0627,  0.0533,  0.0892,  ...,  0.0223,  0.0055, -0.3010],
          [ 0.0280,  0.0501,  0.0121,  ...,  0.0184,  0.0125, -0.2738],
          [-0.3056, -0.0306,  0.1429,  ...,  0.0135, -0.0400, -0.1876]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[  90, 1449,  709,  ..., 1151,  329,  756]]),
 tensor([], size=(1, 0), dtype=torch.int64),
 tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]))

In [20]:
emb[0].shape

torch.Size([1, 1538, 768])

In [48]:
from pytorch_lightning import LightningModule, Trainer

class UNet(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()        
        self.decoder = torch.nn.Sequential(
            torch.nn.Conv2d(4608, 64, kernel_size=(1,1), stride=(1,1)),
            torch.nn.ConvTranspose2d(64, 128, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.ConvTranspose2d(128, 512, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(512, 1, kernel_size=3, padding=1),
            torch.nn.Upsample(scale_factor=16),
            #torch.nn.Conv2d(512, 1, kernel_size=3, padding=1),
            torch.nn.Softmax(dim=1),
            )

    def forward(self, x):
        batch = x
        batch["pixels"] = batch["pixels"].to(model_clay.device)
        batch["timestep"] = batch["timestep"].to(model_clay.device)
        batch["latlon"] = batch["latlon"].to(model_clay.device)
        emb = model_clay.model.encoder(batch)
        embeddings = emb[0] #space["embeddings"]
        embeddings = embeddings[:,:-2,:]
        latent = rearrange(embeddings, "b (g l) d  -> b g l d", g=6)
        latent = rearrange(latent, "b g (h w) d -> b g h w d", h=16, w=16)
        latent = rearrange(latent, "b g h w d -> b (g d) h w")
        x = self.decoder(latent)
        return x


class SegmentationModel(LightningModule):
    def __init__(self, model, datamodule):
        super().__init__()
        self.model = model
        self.datamodule = datamodule
    
    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, x_, y = batch, batch["pixels"], batch["labels"]
        # x = torch.tensor(x, requires_grad=True)
        #y = y.to(dtype=torch.float32)
        #y = y.squeeze()
        #print("Shapes - x:", x.shape, "y:", y.shape)
        y = y[0, :, :, :]
        y_pred = self.model(x)
        _, prediction = torch.max(y_pred, dim=1)
        print("Prediction shape:", prediction.shape)
        print("Label shape:", y.shape)
        loss = torch.nn.functional.cross_entropy(prediction.to(dtype=torch.float32), y)
        loss = torch.tensor(loss, requires_grad=True)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, x_, y = batch, batch["pixels"], batch["labels"]
        # x = torch.tensor(x, requires_grad=True)
        #y = y.to(dtype=torch.float32)
        #y = y.squeeze()
        #print("Shapes - x:", x.shape, "y:", y.shape)
        y = y[0, :, :, :]
        y_pred = self.model(x)
        _, prediction = torch.max(y_pred, dim=1)
        print("Prediction shape:", prediction.shape)
        print("Label shape:", y.shape)
        val_loss = torch.nn.functional.cross_entropy(
            prediction.to(dtype=torch.float32), y
        )
        val_loss = torch.tensor(val_loss, requires_grad=True)
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        return optimizer
    
    def train_dataloader(self):
        return self.datamodule.train_dataloader()

    def val_dataloader(self):
        return self.datamodule.val_dataloader()

dm = ClayDataModule(data_dir=data_dir, batch_size=4)
dm.setup()
#val_dl = iter(dm.val_dataloader())

model_unet = UNet(13, 2)
model_clay = CLAYModule.load_from_checkpoint("../mae_epoch-10_val-loss-0.563.ckpt", mask_ratio=0.)
model_clay.eval();
segmentation_model = SegmentationModel(model_unet, dm)

trainer = Trainer(max_epochs=3)
trainer.fit(segmentation_model)

Total number of chips: 19
                                                                                                                                                                                                            

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /Users/lillythomas/Documents/work/clay/lt/benchmark/latest/model/lightning_logs

  | Name  | Type | Params
-------------------------------
0 | model | UNet | 963 K 
-------------------------------
963 K     Trainable params
0         Non-trainable params
963 K     Total params
3.855     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|                                                                                                                                                   | 0/1 [00:00<?, ?it/s]Prediction shape: torch.Size([1, 512, 512])
Label shape: torch.Size([1, 512, 512])
                                                                                                                                                                                                            

  val_loss = torch.tensor(val_loss, requires_grad=True)
/Users/lillythomas/.pyenv/versions/ptod/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:72: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 1. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/Users/lillythomas/.pyenv/versions/ptod/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1933: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0:   0%|                                                                                                                                                                        | 0/2 [00:00<?, ?it/s]Prediction shape: torch.Size([1, 512, 512])
Label shape: torch.Size([1, 512, 512])


  loss = torch.tensor(loss, requires_grad=True)


Epoch 0:  50%|██████████████████████████████████████████████████████████████████████▌                                                                      | 1/2 [00:20<00:20, 20.82s/it, loss=190, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                                                                                                                                     | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                                                                                                        | 0/1 [00:00<?, ?it/s][APrediction shape: torch.Size([1, 512, 512])
Label shape: torch.Size([1, 512, 512])

Validation DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:12<00:00, 12.84s/it][A
Epoch 0: 100%|████████████████████████████████████████

In [None]:
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
import numpy

def plot_predictions(model, dataloader):
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            x_val_, x_val, y_val = batch, batch["pixels"], batch["labels"]
            y_pred = model(x_val_)
            _, prediction = torch.max(y_pred, dim=1)

            # Convert tensors to NumPy arrays for plotting
            x_val_np = x_val.cpu().numpy()
            y_val_np = y_val.cpu().numpy()
            prediction_np = prediction.cpu().numpy()

            # Plot images and masks
            num_samples = len(x_val)  # Number of samples to visualize
            fig, axes = plt.subplots(1, 3, figsize=(10, 10))

            print(np.unique(y_val_np))
            x_val_np = x_val_np.squeeze()
            x_val_np = x_val_np.transpose(1, 2, 0)
            x_val_np = np.stack((x_val_np[:,:,2], x_val_np[:,:,1], x_val_np[:,:,0]))
            axes[0].imshow(x_val_np.transpose(1,2,0).clip(0, 3000) / 3000) # Plot image
            axes[1].imshow(
                y_val_np.squeeze()
            )  # Plot ground truths
            axes[2].imshow(prediction_np.squeeze())  # Plot model predictions

            plt.show()


# Load the trained model
# loaded_model = UNet(13, 2)  # Initialize the model architecture
# loaded_model.load_state_dict(torch.load('path_to_your_trained_model.pth'))  # Load trained weights

# Run predictions and plot results
plot_predictions(model_unet, dm.val_dataloader())
