In [None]:
cd "/workspace"

In [None]:
!pip install -r "requirements.txt"

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# === Positional Encoding === #
class PositionalEncoding(nn.Module):
    def __init__(self, height, width, channels):
        super(PositionalEncoding, self).__init__()
        self.height = height
        self.width = width
        self.channels = channels

    def forward(self, x):
        # Create positional encodings
        batch_size, _, h, w = x.size()
        pe_x = torch.linspace(0, 1, self.width, device=x.device).unsqueeze(0).repeat(h, 1)
        pe_y = torch.linspace(0, 1, self.height, device=x.device).unsqueeze(1).repeat(1, w)
        pe_x = pe_x.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, h, w)
        pe_y = pe_y.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, h, w)
        return torch.cat([x, pe_x, pe_y], dim=1)

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_attention=False):
        super(ResidualBlock, self).__init__()
        self.use_attention = use_attention

        # Convolutional path
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, padding=2)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=5, padding=2)
        self.norm2 = nn.GroupNorm(8, out_channels)

        # Residual projection for matching dimensions
        self.residual_projection = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )

        if use_attention:
            self.attention = nn.MultiheadAttention(out_channels, num_heads=4, batch_first=True)

    def forward(self, x):
        # Project residual if necessary
        residual = self.residual_projection(x)

        # Main convolution path
        x = F.relu(self.norm1(self.conv1(x)))
        x = self.norm2(self.conv2(x))

        # Optional attention
        if self.use_attention:
            batch, channels, height, width = x.shape
            x_flat = x.view(batch, channels, height * width).permute(0, 2, 1)  # Flatten for attention
            x = self.attention(x_flat, x_flat, x_flat)[0]
            x = x.permute(0, 2, 1).view(batch, channels, height, width)

        return F.relu(x + residual)


# === Adjusted U-Net === #
# === Adjusted U-Net Decoder === #
class UNet(nn.Module):
    def __init__(self, in_channels, out_channels, base_channels=512):
        super(UNet, self).__init__()
        self.encoder1 = ResidualBlock(in_channels, base_channels)
        self.encoder2 = ResidualBlock(base_channels, base_channels * 2)
        self.encoder3 = ResidualBlock(base_channels * 2, base_channels * 4)

        self.middle = ResidualBlock(base_channels * 4, base_channels * 4, use_attention=True)

        self.decoder3 = ResidualBlock(base_channels * 4, base_channels * 2)
        self.decoder2 = ResidualBlock(base_channels * 2, base_channels)
        self.decoder1 = nn.Conv2d(base_channels, out_channels, kernel_size=1)

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(F.avg_pool2d(enc1, kernel_size=2))
        enc3 = self.encoder3(F.avg_pool2d(enc2, kernel_size=2))

        mid = self.middle(enc3)

        dec3 = self.decoder3(F.interpolate(mid, scale_factor=2, mode='bilinear', align_corners=False))
        dec2 = self.decoder2(F.interpolate(dec3, scale_factor=2, mode='bilinear', align_corners=False))
        dec1 = self.decoder1(F.interpolate(dec2, size=x.size()[-2:], mode='bilinear', align_corners=False))  # Match input size

        return dec1


# === Diffusion Model with Adjustments === #
class DDPM(nn.Module):
    def __init__(self, image_size=313, channels=502, timesteps=1000):
        super(DDPM, self).__init__()
        self.timesteps = timesteps
        self.image_size = image_size

        # Use UNet directly
        self.unet = UNet(in_channels=channels, out_channels=channels)

    def forward(self, x, t):
        # Forward pass through U-Net
        return self.unet(x)

class DiffusionTraining:
    def __init__(self, model, timesteps=1000, beta_start=0.0001, beta_end=0.02):
        self.model = model
        self.timesteps = timesteps

        # Beta schedule
        self.betas = torch.linspace(beta_start, beta_end, timesteps)
        self.alphas = 1.0 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def forward_diffusion(self, x_0, t):
        noise = torch.randn_like(x_0)
        # Ensure alpha_bars is on the same device as t
        alpha_bar_t = self.alpha_bars.to(t.device)[t].view(-1, 1, 1, 1)
        return torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * noise, noise

    def train_step(self, x_0, optimizer):
        t = torch.randint(0, self.timesteps, (x_0.size(0),), device=x_0.device)
        x_t, noise = self.forward_diffusion(x_0, t)
        noise_pred = self.model(x_t, t)
        loss = F.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

# === Initialize Model === #
image_size = 313
channels = 502
model = DDPM(image_size, channels)
print("Done")

KeyboardInterrupt: 

In [None]:
import h5py
high_res_file = "./padded_high_res_grids_3.h5"
with h5py.File(high_res_file, "r") as high_res_f:
        high_res_data = high_res_f["0"][()]
print(high_res_data[501,100:200,100:200])
#Testing for file corruption and encoding error! All good now

In [None]:
import h5py
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

# === PairedGridDatasetH5 Class === #
class PairedGridDatasetH5(Dataset):
    def __init__(self, high_res_file, low_res_file, keys):
        """
        Dataset for paired high-res and low-res grids preloaded into RAM.
        :param high_res_file: Path to the high-res HDF5 file.
        :param low_res_file: Path to the low-res HDF5 file.
        :param keys: List of keys identifying the datasets to load.
        """
        self.data = []
        self.keys = keys

        print("Preloading data into RAM...")
        with h5py.File(high_res_file, "r") as high_res_f, h5py.File(low_res_file, "r") as low_res_f:
            lf=low_res_f["low_res"]
            for key in keys:
                # Load data into memory
                low_res_grid = lf[key][...]  # Entire low-res grid
                high_res_grid = high_res_f[key][...]  # Entire high-res grid
                self.data.append((low_res_grid, high_res_grid))
        print("Data preloaded.")

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        """
        Access preloaded data and create a mask dynamically.
        """
        low_res_grid, high_res_grid = self.data[idx]

        # Convert grids to PyTorch tensors
        low_res_grid = torch.tensor(low_res_grid, dtype=torch.float32)
        high_res_grid = torch.tensor(high_res_grid, dtype=torch.float32)

        # Create mask
        mask = (high_res_grid[0] != -1).float()  # Use the first channel as reference
        mask = mask.unsqueeze(0).expand(502, -1, -1)  # Match 502 channels
        return low_res_grid, high_res_grid, mask
        self.keys = keys

# === File Paths === #
high_res_file = "./padded_high_res_grids_3.h5"
low_res_file = "./padded_intermediate_BAYESSPACE.h5"

# === Load Keys and Shuffle === #
with h5py.File(high_res_file, "r") as h5f:
    all_keys = list(h5f.keys())  #The keys match between high_res and low_res files

# Shuffle keys
np.random.shuffle(all_keys)

# === Train/Val/Test Split === #
train_ratio, val_ratio = 0.8, 0.1
train_keys = all_keys[:int(train_ratio * len(all_keys))]
val_keys = all_keys[int(train_ratio * len(all_keys)):int((train_ratio + val_ratio) * len(all_keys))]
test_keys = all_keys[int((train_ratio + val_ratio) * len(all_keys)):]

# === Create Datasets === #
train_dataset = PairedGridDatasetH5(high_res_file, low_res_file, train_keys)
val_dataset = PairedGridDatasetH5(high_res_file, low_res_file, val_keys)
test_dataset = PairedGridDatasetH5(high_res_file, low_res_file, test_keys)
# === Create DataLoaders === #
batch_size = 5
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# === Sanity Check: Verify Data Loading === #
print(f"Train set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")
print(f"Test set size: {len(test_dataset)}")
for batch in train_loader:
    low_res, high_res, mask = batch
    print(f"Low-res shape: {low_res.shape}")
    print(f"High-res shape: {high_res.shape}")
    if mask is not None:
        print(f"Mask shape: {mask.shape}")
    break
print("Done")


In [None]:
#DEBUGGING CELL
with h5py.File(high_res_file, "r") as high_res_f, h5py.File(low_res_file, "r") as low_res_f:
            lf=low_res_f["low_res"]
            print(lf["0"][...])
            #no difference between () and ... for conversion time

In [None]:
# Initialize DDPM model
from torch.optim import Adam

device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_epochs=20
image_size = 313
input_channels = 502
model = DDPM(image_size, input_channels).to(device)

# Optimizer
optimizer = Adam(model.parameters(), lr=5e-4)

# Diffusion Framework for Training
trainer = DiffusionTraining(model)

# Checkpoint path
checkpoint_path = '/content/drive/MyDrive/checkpoints/ddpm_checkpoint.pth'


def save_checkpoint(epoch, model, optimizer, train_losses, val_losses, checkpoint_path):
    """
    Save training checkpoint.
    :param epoch: Current epoch.
    :param model: DDPM model instance.
    :param optimizer: Optimizer instance.
    :param train_losses: List of training losses.
    :param val_losses: List of validation losses.
    :param checkpoint_path: Path to save the checkpoint.
    """
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
    }, checkpoint_path)


# Training loop with memory optimizations
accumulation_steps = 4  # Simulate larger batch size by accumulating gradients

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    model.train()
    optimizer.zero_grad()
    epoch_train_loss = 0

    for i, (lr_data, hr_data, mask) in enumerate(train_loader):
        # Move data to GPU when needed
        lr_data = lr_data.to(device, non_blocking=True)
        hr_data = hr_data.to(device, non_blocking=True)
        mask = mask.to(device, non_blocking=True)

        # Forward diffusion process
        t = torch.randint(0, trainer.timesteps, (lr_data.size(0),), device=device)
        x_t, noise = trainer.forward_diffusion(hr_data, t)

        # Predict noise
        noise_pred = model(x_t, t)

        # Masked loss calculation
        loss = F.mse_loss(noise_pred * mask, noise * mask, reduction='sum') / mask.sum()

        # Normalize loss for accumulation
        loss = loss / accumulation_steps
        loss.backward()

        # Perform optimizer step after accumulating gradients
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        epoch_train_loss += loss.item()

        # Free GPU memory
        del lr_data, hr_data, mask, x_t, noise, noise_pred
        torch.cuda.empty_cache()

    # Final step for remaining gradients (if total batches are not divisible by accumulation_steps)
    if len(train_loader) % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

    # Log training loss
    avg_train_loss = epoch_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    print(f"Train Loss: {avg_train_loss:.4f}")

    # Validation loop
    model.eval()
    epoch_val_loss = 0
    with torch.no_grad():
        for lr_data, hr_data, mask in val_loader:
            lr_data = lr_data.to(device, non_blocking=True)
            hr_data = hr_data.to(device, non_blocking=True)
            mask = mask.to(device, non_blocking=True)

            # Forward diffusion process
            t = torch.randint(0, trainer.timesteps, (lr_data.size(0),), device=device)
            x_t, noise = trainer.forward_diffusion(hr_data, t)
            noise_pred = model(x_t, t)

            # Masked validation loss
            val_loss = F.mse_loss(noise_pred * mask, noise * mask, reduction='sum') / mask.sum()
            epoch_val_loss += val_loss.item()

            # Free GPU memory
            del lr_data, hr_data, mask, x_t, noise, noise_pred
            torch.cuda.empty_cache()

    avg_val_loss = epoch_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)
    print(f"Validation Loss: {avg_val_loss:.4f}")

    # Save checkpoint
    save_checkpoint(epoch, model, optimizer, train_losses, val_losses, checkpoint_path)

# Test evaluation
model.eval()
epoch_test_loss = 0
with torch.no_grad():
    for lr_data, hr_data, mask in test_loader:
        lr_data = lr_data.to(device, non_blocking=True)
        hr_data = hr_data.to(device, non_blocking=True)
        mask = mask.to(device, non_blocking=True)

        # Forward diffusion process
        t = torch.randint(0, trainer.timesteps, (lr_data.size(0),), device=device)
        x_t, noise = trainer.forward_diffusion(hr_data, t)
        noise_pred = model(x_t, t)

        # Masked test loss
        test_loss = F.mse_loss(noise_pred * mask, noise * mask, reduction='sum') / mask.sum()
        epoch_test_loss += test_loss.item()

        # Free GPU memory
        del lr_data, hr_data, mask, x_t, noise, noise_pred
        torch.cuda.empty_cache()

avg_test_loss = epoch_test_loss / len(test_loader)
print(f"Test Loss: {avg_test_loss:.4f}")


In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18

# === Utility Functions === #
def align_tensor_sizes(tensor_a, tensor_b):
    """Ensure tensor_a matches the dimensions of tensor_b."""
    _, _, height_t1, width_t1 = tensor_a.size()
    _, _, height_t2, width_t2 = tensor_b.size()

    if height_t1 != height_t2 or width_t1 != width_t2:
        tensor_a = F.interpolate(tensor_a, size=(height_t2, width_t2), mode='bilinear', align_corners=False)

    return tensor_a

# === ResNetUNet Definition === #
class ResNetUNet(nn.Module):
    def __init__(self, pretrained_encoder=None, out_channels=502):
        super(ResNetUNet, self).__init__()

        # Use the provided InterpolationModel as the encoder
        self.encoder = pretrained_encoder

        # Decoder layers
        self.decoder4 = nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2, padding=1)
        self.decoder3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1)
        self.decoder2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1)
        self.decoder1 = nn.ConvTranspose2d(64, out_channels, kernel_size=2, stride=2, padding=1)

        # Skip connections
        self.skip4 = nn.Conv2d(256, 256, kernel_size=1)
        self.skip3 = nn.Conv2d(128, 128, kernel_size=1)
        self.skip2 = nn.Conv2d(64, 64, kernel_size=1)

    def forward(self, x, original_size):
        # Pass input through the InterpolationModel encoder
        enc1 = self.encoder.encoder.conv1(x)  # Initial convolution
        enc1 = self.encoder.encoder.bn1(enc1)  # Batch normalization
        enc1 = F.relu(enc1)  # Add ReLU activation explicitly

        enc2 = self.encoder.encoder.layer1(enc1)  # Residual layer 1
        enc3 = self.encoder.encoder.layer2(enc2)  # Residual layer 2
        enc4 = self.encoder.encoder.layer3(enc3)  # Residual layer 3

        # Bottleneck
        bottleneck = self.encoder.encoder.layer4(enc4)  # Residual layer 4

        # Decoder with skip connections
        dec4 = self.decoder4(bottleneck)
        dec4 = align_tensor_sizes(dec4, self.skip4(enc4)) + self.skip4(enc4)

        dec3 = self.decoder3(dec4)
        dec3 = align_tensor_sizes(dec3, self.skip3(enc3)) + self.skip3(enc3)

        dec2 = self.decoder2(dec3)
        dec2 = align_tensor_sizes(dec2, self.skip2(enc2)) + self.skip2(enc2)

        dec1 = self.decoder1(dec2)

        # Final output resizing to match original size
        return dec1[:, :, :original_size[0], :original_size[1]]

# === Function to Load and Remap State_dict === #
def load_and_remap_state_dict(model, checkpoint_path):
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith("decoder."):
            # Map decoder keys
            new_key = key.replace("decoder.0", "decoder4").replace("decoder.2", "decoder3") \
                         .replace("decoder.4", "decoder2").replace("decoder.6", "decoder1")
        elif key.startswith("skip"):
            # Map skip connection keys
            new_key = key.replace("skip", "skip")
        elif key == "encoder.conv1.weight":
            # Handle conv1 shape mismatch
            new_key = key
        else:
            # Keep other keys unchanged
            new_key = key
        new_state_dict[new_key] = value

    # Load remapped state_dict
    model.load_state_dict(new_state_dict, strict=False)
    print("State_dict successfully remapped and loaded.")

# === Initialize Model and Load Weights === #
device = 'cuda' if torch.cuda.is_available() else 'cpu'
resnet=InterpolationModel()
resnet.to(device)
# Path to checkpoint
checkpoint_path = './interpolation_model.pth'
load_and_remap_state_dict(resnet, checkpoint_path)
model = ResNetUNet(out_channels=502,pretrained_encoder=resnet)
model.to(device)

print("Models are functional")



State_dict successfully remapped and loaded.
Models are functional


In [None]:
import torch.nn.functional as F
from torch.optim import Adam
from tasks.spatial_transcriptomics_dataset import MemoryEfficientPairedLoader

In [20]:

train_high_res_files = [f"binned_data/high_res_grids_{i}.npz" for i in [4, 12,  14, 22]]
train_low_res_files = [f"binned_data/low_res_grids_{i}.npz" for i in [4, 12,  14, 22]]
val_high_res_files = [f"binned_data/high_res_grids_{i}.npz" for i in [9, 27]]
val_low_res_files = [f"binned_data/low_res_grids_{i}.npz" for i in [9, 27]]
test_high_res_files = [f"binned_data/high_res_grids_{i}.npz" for i in [24, 2]]
test_low_res_files = [f"binned_data/low_res_grids_{i}.npz" for i in [24, 2]]

#utilizing my lazy loaders
print("Creating loaders")
train_loader = MemoryEfficientPairedLoader(train_high_res_files, train_low_res_files, batch_size=5)
val_loader = MemoryEfficientPairedLoader(val_high_res_files, val_low_res_files, batch_size=5)
test_loader = MemoryEfficientPairedLoader(test_high_res_files, test_low_res_files, batch_size=5)
print("Created loaders")

optimizer = Adam(model.parameters(), lr=3e-4) #pick higher start

#saving function copy pasted from above
def save_checkpoint(epoch, model, optimizer, train_losses, val_losses, checkpoint_path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
    }, checkpoint_path)

#train and validate
num_epochs = 10
train_losses, val_losses = [], []
checkpoint_path = 'checkpoints/resnet_unet_checkpoint.pth'

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    model.train()
    epoch_train_loss = 0

    for sample in train_loader:
        lr_data = sample['low_res_tensor'].to(device)
        hr_data = sample['high_res_tensor'].to(device)
        mask = sample['mask'].to(device)

        # Forward pass
        noise_pred = model(lr_data, original_size=(313, 313))
        noise_pred = align_tensor_sizes(noise_pred, mask) #dim align

        #masked loss
        total_loss = 0.0
        for i in range(noise_pred.shape[0]):
            mask2 = mask[i:i + 1]
            loss = F.mse_loss(noise_pred[i] * mask2, hr_data[i] * mask2, reduction='sum')
            total_loss += loss

        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_train_loss += total_loss.item()

    avg_train_loss = epoch_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    print(f"Train Loss: {avg_train_loss:.4f}")

    # Validation
    model.eval()
    epoch_val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            lr_data = batch['low_res_tensor'].to(device)
            hr_data = batch['high_res_tensor'].to(device)
            mask = batch['mask'].to(device)

            noise_pred = model(lr_data, original_size=(313, 313))
            noise_pred = align_tensor_sizes(noise_pred, mask)

            val_loss = F.mse_loss(noise_pred * mask, hr_data * mask, reduction='sum') / mask.sum()
            epoch_val_loss += val_loss.item()

    avg_val_loss = epoch_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)
    print(f"Validation Loss: {avg_val_loss:.4f}")

    # Save checkpoint
    save_checkpoint(epoch, model, optimizer, train_losses, val_losses, checkpoint_path)

# Test evaluation
model.eval()
epoch_test_loss = 0
with torch.no_grad():
    for batch in test_loader:
        lr_data = batch['low_res_tensor'].to(device)
        hr_data = batch['high_res_tensor'].to(device)
        mask = batch['mask'].to(device)

        noise_pred = model(lr_data, original_size=(313, 313))
        noise_pred = align_tensor_sizes(noise_pred, mask)

        test_loss = F.mse_loss(noise_pred * mask, hr_data * mask, reduction='sum') / mask.sum()
        epoch_test_loss += test_loss.item()

avg_test_loss = epoch_test_loss / len(test_loader)
print(f"Test Loss: {avg_test_loss:.4f}")

# Save final checkpoint
save_checkpoint(num_epochs, model, optimizer, train_losses, val_losses, checkpoint_path)


Creating loaders
Created loaders
Epoch 1/10
Train Loss: 46735657.4000
Validation Loss: 0.3457
Epoch 2/10
Train Loss: 40633651.8000
Validation Loss: 0.3237
Epoch 3/10
Train Loss: 37912029.2000
Validation Loss: 0.3064
Epoch 4/10
Train Loss: 37266074.8000
Validation Loss: 0.3027
Epoch 5/10
Train Loss: 36937856.7000
Validation Loss: 0.3048
Epoch 6/10
Train Loss: 35986738.8000
Validation Loss: 0.2943
Epoch 7/10
Train Loss: 36441739.4000
Validation Loss: 0.2967
Epoch 8/10
Train Loss: 36183274.8000
Validation Loss: 0.2996
Epoch 9/10
Train Loss: 35351123.1000
Validation Loss: 0.3052
Epoch 10/10
Train Loss: 35606189.0000
Validation Loss: 0.2860
Test Loss: 0.3157


In [1]:
!pip install matplotlib

Collecting matplotlib
  Downloading matplotlib-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Downloading contourpy-1.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.4 kB)
Collecting cycler>=0.10 (from matplotlib)
  Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Downloading fonttools-4.55.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (165 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m165.1/165.1 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting kiwisolver>=1.3.1 (from matplotlib)
  Downloading kiwisolver-1.4.7-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (6.3 kB)
Downloading matplotlib-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.6/8.6 MB[0m 

In [None]:
data = np.load("binned_data/low_res_grids_1.npz") #Pick a file we have not used 
low_res_grid = data["grids"][4]  # Extract the specific grid
mask2=(low_res_grid[200]>0)
low_res_tensor = torch.tensor(low_res_grid, dtype=torch.float32).unsqueeze(0).to(device)  # [1, 502, 313, 313]

#inference/superresolve the lr image
model.eval()
with torch.no_grad():#disable gradient to avoid unnecessary computations
    high_res_pred = model(x=low_res_tensor, original_size=(313, 313))
grid=np.array(high_res_pred.cpu())

mask = (grid[0,200,:,:] > 0)
channel_data = grid[0,200, :, :]*mask*mask2
plt.figure(figsize=(10, 10))
myim=plt.imshow(channel_data, cmap='viridis', interpolation='nearest')
cbar = plt.colorbar(myim, label=f'Channel {501} Intensity', shrink=0.8)

cbar.ax.tick_params(labelsize=12)  # Make color bar labels larger
cbar.set_label(f'Channel {200} Intensity', fontsize=14)  # Color bar title larger

plt.title(f'Visualization of ResUnet #10_256_9E4 Upscaled Synthetic Visium Gene Expression Values for: Spatial Coord X #{201}', fontsize=12)
plt.xlabel('X Index', fontsize=14)
plt.ylabel('Y Index', fontsize=14)
plt.show()

In [None]:
data = np.load("binned_data/low_res_grids_28.npz") #Pick a file we have not used 
low_res_grid = data["grids"][4]  # Extract the specific grid
mask2=(low_res_grid[200]>0)
low_res_tensor = torch.tensor(low_res_grid, dtype=torch.float32).unsqueeze(0).to(device)  # [1, 502, 313, 313]

#inference/superresolve the lr image
model.eval()
with torch.no_grad():#disable gradient to avoid unnecessary computations
    high_res_pred = model(x=low_res_tensor, original_size=(313, 313))
grid=np.array(high_res_pred.cpu())

mask = (grid[0,200,:,:] > 0)
channel_data = grid[0,200, :, :]*mask*mask2
plt.figure(figsize=(10, 10))
myim=plt.imshow(channel_data, cmap='viridis', interpolation='nearest')
cbar = plt.colorbar(myim, label=f'Channel {200} Intensity', shrink=0.5)

cbar.ax.tick_params(labelsize=12)  # Make color bar labels larger
cbar.set_label(f'Channel {200} Intensity', fontsize=14)  # Color bar title larger

plt.title(f'Visualization of ResUnet #12_256_9E4 Upscaled Synthetic Visium Gene Expression Values for: Spatial X Cord #{200}', fontsize=12)
plt.xlabel('X Index', fontsize=14)
plt.ylabel('Y Index', fontsize=14)
plt.show()

In [None]:
!pip install tensorflow
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr, mean_squared_error as mse
import tensorflow as tf
def load_grid_from_npz(npz_file, grid_key):
    #Loads grids from an npz file.
    with np.load(npz_file, allow_pickle=True) as data:
        return data[grid_key]

def compute_metrics(high_res_tensor, low_res_tensor):

    #Computes SSIM, PSNR, and MSE between high-res and low-res grids across all channels.

    ssim_values = []
    psnr_values = []
    mse_values = []

    for channel in range(502):  # Iterate over all channels
        high_res_channel = high_res_tensor[channel, :, :]
        low_res_channel = low_res_tensor[channel, :, :]
        data_range = max(high_res_channel.max(), low_res_channel.max()) - min(high_res_channel.min(), low_res_channel.min())
        channel_ssim, _ = ssim(high_res_channel, low_res_channel, data_range=data_range, full=True)
        ssim_values.append(channel_ssim)

        channel_psnr = psnr(high_res_channel, low_res_channel, data_range=data_range)
        psnr_values.append(channel_psnr)

        channel_mse = mse(high_res_channel, low_res_channel)
        mse_values.append(channel_mse)
    mean_ssim = np.mean(ssim_values)
    mean_psnr = np.mean(psnr_values)
    mean_mse = np.mean(mse_values)

    return ssim_values, psnr_values, mse_values, mean_ssim, mean_psnr, mean_mse

#File paths and grid index
high_res_file = 'binned_data/high_res_grids_28.npz'
low_res_file = 'binned_data/high_res_grids_28.npz'
grid_key = "grids"
grid_index = 4  # Example grid index

#Load the grids
high_res_grids = load_grid_from_npz(high_res_file, grid_key)
low_res_grids = load_grid_from_npz(low_res_file, grid_key)

high_res_grid = high_res_grids[grid_index]
low_res_grid = low_res_grids[grid_index]
mask = (low_res_grid != -1)
high_res_grid=high_res_grid*mask
low_res_grid=low_res_grid*mask
ssim_values, psnr_values, mse_values, mean_ssim, mean_psnr, mean_mse = compute_metrics(high_res_grid, low_res_grid)
ssim_values, psnr_values, mse_values, mean_ssim, mean_psnr, mean_mse = compute_metrics(high_res_grid, channel_data)
# Display results
print(f"SSIM values for -30 channel: {ssim_values[100]}")
print(f"Mean SSIM across all channels: {mean_ssim}")
print(f"PSNR values for 30channel: {psnr_values[100]}")
print(f"Mean PSNR across all channels: {mean_psnr}")
print(f"MSE values for -30 channel: {mse_values[100]}")
print(f"Mean MSE across all channels: {mean_mse}")


In [18]:
class InterpolationModel(nn.Module):
    def __init__(self):
        super(InterpolationModel, self).__init__()
        # Encoder: Use ResNet18 for feature extraction
        self.encoder = resnet18(weights=None)  #No pretrained weights
        self.encoder.conv1 = nn.Conv2d(502, 64, kernel_size=5, stride=2, padding=3, bias=False)
        self.encoder.fc = nn.Identity()  #Remove FC layer for features

        # Decoder: Upsample to match [313, 313]
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2, padding=1),  
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1),  
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2, padding=1),   
            nn.ReLU(),
            nn.ConvTranspose2d(32, 502, kernel_size=2, stride=1, padding=1),  
            nn.Upsample(size=(313, 313), mode='bilinear', align_corners=True) #final upscale to [313x313]
        )

    def forward(self, x):
        features = self.encoder(x)  # Extract features
        features = features.view(features.size(0), -1, 1, 1)  # Reshape for decoding
        output = self.decoder(features)  #Decode to high-resolution size
        return output


In [15]:
! pip install torch

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
