# U-Net using PyTorch

In [68]:
import numpy as np

# Training data: flattened patches (X) and their masks (Y), one sample per row.
X_train_b = np.load("Xtrain2_b.npy")
Y_train_b = np.load("Ytrain2_b.npy")

# View every row as a 48x48 image; -1 lets NumPy infer the number of samples.
images, masks = (arr.reshape(-1, 48, 48) for arr in (X_train_b, Y_train_b))


Dataset Class for In-Memory Data

In [69]:
import torch
from torch.utils.data import Dataset

class NumpyDataset(Dataset):
    """
    Segmentation dataset backed by in-memory arrays.

    Each item is an ``(image, mask)`` pair of float32 tensors with a leading
    channel axis prepended, i.e. a 2-D sample of shape (H, W) becomes (1, H, W).

    Args:
        images: indexable collection of 2-D images.
        masks: indexable collection of 2-D masks, same length as ``images``.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries.
    """
    def __init__(self, images, masks, transform=None):
        assert len(images) == len(masks), "Images and masks should have the same length."
        self.images, self.masks, self.transform = images, masks, transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Prepend the channel axis: (H, W) -> (1, H, W).
        sample_img = self.images[idx][None, ...]
        sample_mask = self.masks[idx][None, ...]

        if self.transform:
            out = self.transform(image=sample_img, mask=sample_mask)
            sample_img = out["image"]
            sample_mask = out["mask"]

        return tuple(torch.tensor(t, dtype=torch.float32) for t in (sample_img, sample_mask))



U-Net Model Definition

In [70]:
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> LeakyReLU) stages followed by dropout.

    ``padding=1`` keeps the spatial size unchanged; only the channel count goes
    from ``in_channels`` to ``out_channels``.
    """
    def __init__(self, in_channels, out_channels, dropout_rate=0.1):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(dropout_rate)
        )

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    """U-Net with skip connections that returns raw per-pixel logits.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map.
        features: encoder stage widths, shallow to deep (any sequence). The
            bottleneck uses ``2 * features[-1]`` channels. Each stage halves the
            spatial size; when pooling floors an odd size, the upsampled map is
            resized to match its skip connection.
    """
    # Tuple default: a list default is a single shared mutable object across calls.
    def __init__(self, in_channels=1, out_channels=2, features=(64, 128, 256, 512)):
        super(UNet, self).__init__()
        features = list(features)
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        # Built once instead of on every forward pass; it has no parameters or
        # buffers, so existing state_dict checkpoints still load unchanged.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        for feature in features:
            self.encoder.append(DoubleConv(in_channels, feature))
            in_channels = feature

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Decoder alternates [upsample, DoubleConv] per stage, deepest stage first.
        for feature in reversed(features):
            self.decoder.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.decoder.append(DoubleConv(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skips = []
        for stage in self.encoder:
            x = stage(x)
            skips.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Each (upsample, conv) pair consumes the matching skip, deepest first.
        for i in range(0, len(self.decoder), 2):
            x = self.decoder[i](x)
            skip = skips[-(i // 2 + 1)]

            # Only the spatial size can differ (pooling floors odd sizes).
            if x.shape[2:] != skip.shape[2:]:
                x = nn.functional.interpolate(x, size=skip.shape[2:], mode="bilinear", align_corners=True)

            x = torch.cat((skip, x), dim=1)
            x = self.decoder[i + 1](x)

        return self.final_conv(x)


Training Function

In [71]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5, device="cpu"):
    """Train ``model`` and report the mean train/val loss after every epoch.

    Each epoch makes one optimisation pass over ``train_loader``, then one
    gradient-free evaluation pass over ``val_loader``.

    Args:
        model: network to train; moved to ``device`` in place.
        train_loader: iterable of ``(images, masks)`` batches used for updates.
        val_loader: iterable of ``(images, masks)`` batches used only for evaluation.
        criterion: callable ``criterion(outputs, masks)`` returning a scalar loss tensor.
        optimizer: optimizer already bound to ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        device: device that the model and batches are moved to.

    Returns:
        A list of ``(train_loss, val_loss)`` tuples, one per epoch, each the mean
        loss per batch. A loader with no batches reports ``0.0`` instead of
        raising ``ZeroDivisionError``. The model is left in eval mode.
    """
    model.to(device)
    history = []

    for epoch in range(num_epochs):
        model.train()
        train_loss, train_batches = 0.0, 0

        for images, masks in tqdm(train_loader, leave=True):
            images, masks = images.to(device), masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_batches += 1

        model.eval()
        val_loss, val_batches = 0.0, 0
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                val_loss += criterion(outputs, masks).item()
                val_batches += 1

        # Count batches instead of len(loader): this guards against empty
        # loaders and also supports iterables without __len__.
        mean_train = train_loss / max(train_batches, 1)
        mean_val = val_loss / max(val_batches, 1)
        history.append((mean_train, mean_val))
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {mean_train:.4f}, "
              f"Val Loss: {mean_val:.4f}")

    return history



In [90]:
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

def check_accuracy(loader, model, device="cpu", threshold=0.5):
    """Print and return the pixel-wise balanced accuracy of ``model`` on ``loader``.

    Model outputs are treated as logits: a pixel is predicted positive when
    ``sigmoid(output) > threshold``. Balanced accuracy is the mean of
    sensitivity TP / (TP + FN) and specificity TN / (TN + FP). A rate whose
    denominator is zero counts as 0.

    Args:
        loader: iterable of ``(images, masks)`` batches, where masks use 1 for
            positive pixels and 0 for negative pixels.
        model: segmentation network returning one logit per pixel.
        device: device that the batches are moved to.
        threshold: probability cut-off for a positive prediction.

    Returns:
        The balanced accuracy as a float in [0, 1]. The model's original
        train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    tp = tn = fp = fn = 0

    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            preds = (torch.sigmoid(model(images)) > threshold).float()

            tp += ((preds == 1) & (masks == 1)).sum().item()
            tn += ((preds == 0) & (masks == 0)).sum().item()
            fp += ((preds == 1) & (masks == 0)).sum().item()
            fn += ((preds == 0) & (masks == 1)).sum().item()

    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    balanced_accuracy = (sensitivity + specificity) / 2
    print(f"Balanced Accuracy: {balanced_accuracy * 100:.2f}%")

    # Restore the caller's mode: forcing train() here would silently re-enable
    # dropout/batch-norm updates on a model that was loaded for inference.
    model.train(was_training)
    return balanced_accuracy

import matplotlib.pyplot as plt

def save_predictions_as_imgs(loader, model, num_images=6, device="cpu"):
    """Display up to ``num_images`` (image, prediction, ground truth) triplets.

    Despite the name, figures are shown with ``plt.show()`` and are not written
    to disk. Predictions are ``sigmoid(model(images)) > 0.5``, so the model is
    expected to return logits.

    Args:
        loader: iterable of ``(images, masks)`` batches.
        model: segmentation network returning one logit per pixel.
        num_images: maximum number of samples to display.
        device: device that the batches are moved to.

    The model's original train/eval mode is restored afterwards, even if
    plotting raises.
    """
    was_training = model.training
    model.eval()
    images_shown = 0

    try:
        for images, masks in loader:
            # Stop before running the model on a batch that will not be shown.
            if images_shown >= num_images:
                break
            images, masks = images.to(device), masks.to(device)
            with torch.no_grad():
                preds = (torch.sigmoid(model(images)) > 0.5).float()

            for i in range(images.shape[0]):
                if images_shown >= num_images:
                    break
                fig, axes = plt.subplots(1, 3, figsize=(12, 4))
                panels = (("Image", images[i]), ("Prediction", preds[i]), ("Ground Truth", masks[i]))
                for ax, (title, tensor) in zip(axes, panels):
                    ax.set_title(title)
                    ax.imshow(tensor.cpu().squeeze(), cmap="gray")
                plt.show()
                plt.close(fig)  # avoid accumulating open figures
                images_shown += 1
    finally:
        model.train(was_training)



Main Code

In [None]:
from sklearn.model_selection import train_test_split

# Seed weight initialisation and DataLoader shuffling so Restart & Run All is
# reproducible (train_test_split is already seeded via random_state).
torch.manual_seed(42)
device = "cpu"

X_train, X_val, Y_train, Y_val = train_test_split(images, masks, test_size=0.2, random_state=42)

train_dataset = NumpyDataset(X_train, Y_train)
val_dataset = NumpyDataset(X_val, Y_val)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

# One logit per pixel, matching BCEWithLogitsLoss below.
model = UNet(in_channels=1, out_channels=1)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5, device=device)

check_accuracy(val_loader, model, device=device)
save_predictions_as_imgs(val_loader, model, num_images=6, device=device)

torch.save(model.state_dict(), "u-net_pytorch1.pth")



In [None]:
def load_model(model_path, device="cpu"):
    """Build a single-output ``UNet`` and load a saved ``state_dict`` into it.

    Args:
        model_path: path to a ``.pth`` file written with ``torch.save(model.state_dict(), ...)``.
        device: device to map the stored tensors to and to move the model onto.

    Returns:
        The model on ``device``, in eval mode.
    """
    model = UNet(in_channels=1, out_channels=1)
    # weights_only=True (PyTorch >= 1.13) limits unpickling to tensors and plain
    # containers, so a tampered checkpoint cannot execute code on load.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model

# Reload the checkpoint written by the training cell ("u-net_pytorch1.pth").
# The previous path, "u-net_pytorch2.pth", is not saved by any cell shown here,
# so a fresh Restart & Run All would fail or evaluate a stale model.
model_path = "u-net_pytorch1.pth"
device = "cpu"
model = load_model(model_path, device=device)

check_accuracy(val_loader, model, device=device)

# U-Net with Class Weights

In [57]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from sklearn.utils import class_weight

Data Initialization

In [58]:
X_train_b = np.load("Xtrain2_b.npy")
Y_train_b = np.load("Ytrain2_b.npy")

# (N, 1, 48, 48): add the channel axis that Conv2d expects. -1 infers N instead
# of hard-coding the sample count (547), so a differently sized file still loads.
X_train_b = X_train_b.reshape(-1, 1, 48, 48)
Y_train_b = Y_train_b.reshape(-1, 1, 48, 48)

X_train_tensor = torch.tensor(X_train_b, dtype=torch.float32)
Y_train_tensor = torch.tensor(Y_train_b, dtype=torch.float32)

X_train, X_val, y_train, y_val = train_test_split(X_train_tensor, Y_train_tensor, test_size=0.2, random_state=42)

train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

# Seeded generator so the shuffling order is reproducible across runs.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True,
                          generator=torch.Generator().manual_seed(42))
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

U-Net Module

In [59]:

class UNet(nn.Module):
    """Encoder-decoder segmentation network: 1-channel input, 1-channel output.

    NOTE(review): despite the name, there are no skip connections between the
    encoder and decoder (``forward`` is ``decoder(encoder(x))``), so this is not
    a U-Net in the usual sense.

    The encoder has three conv/pool stages (32 -> 64 -> 128 channels, each
    halving H and W) plus a 256-channel conv stage without pooling. The decoder
    has three stride-2 transposed convolutions (each doubling H and W) down to
    32 channels, then a 1x1 conv to one channel. A 48x48 input therefore yields
    a 48x48 output.

    ``forward`` returns sigmoid probabilities in [0, 1], not logits. A loss that
    applies its own sigmoid (e.g. ``nn.BCEWithLogitsLoss``) must not be fed
    these outputs directly.
    """
    def __init__(self):
        super(UNet, self).__init__()
        # Attribute names and layer order define the state_dict keys of saved checkpoints.
        self.encoder = nn.Sequential(
            # Stage 1: 1 -> 32 channels, then 2x downsampling.
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Stage 2: 32 -> 64 channels, then 2x downsampling.
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Stage 3: 64 -> 128 channels, then 2x downsampling.
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Deepest stage: 128 -> 256 channels, no pooling.
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )

        self.decoder = nn.Sequential(
            # Up 1: 2x upsampling, 256 -> 128 channels.
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # Up 2: 2x upsampling, 128 -> 64 channels.
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            # Up 3: 2x upsampling, 64 -> 32 channels, then a 1x1 projection to one channel.
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=1)  # per-pixel score before the sigmoid
        )

        # Converts scores to probabilities inside the model (see class docstring).
        self.final_activation = nn.Sigmoid()

    def forward(self, x):
        """Return per-pixel foreground probabilities for a batch ``x``."""
        x = self.encoder(x)
        x = self.decoder(x)
        return self.final_activation(x)  # probabilities, not logits

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = UNet()
model = model.to(device)

Dice Loss and Class Weights 

In [60]:
def dice_loss(preds, targets, epsilon=1e-6):
    """Soft Dice loss: ``1 - (2 * sum(P * T) + eps) / (sum(P) + sum(T) + eps)``.

    Args:
        preds: predicted foreground scores (e.g. probabilities), any shape.
        targets: ground-truth mask with the same number of elements as ``preds``.
        epsilon: smoothing term that avoids 0/0 when both inputs are all zero.

    Returns:
        A scalar tensor: 0 for perfect overlap, approaching 1 for no overlap.
    """
    # reshape (not view) so non-contiguous inputs, such as transposed or
    # sliced tensors, are accepted instead of raising.
    preds = preds.reshape(-1)
    targets = targets.reshape(-1)
    intersection = (preds * targets).sum()
    return 1 - (2. * intersection + epsilon) / (preds.sum() + targets.sum() + epsilon)

def calculate_class_weights(y):
    """Return scikit-learn 'balanced' class weights for the labels in ``y``.

    Each class c gets ``n_samples / (n_classes * count(c))``, ordered by
    ascending label (the order of ``np.unique``).

    Args:
        y: tensor of labels (any shape, any device); flattened before counting.

    Returns:
        A float32 tensor with one weight per distinct label present in ``y``.
    """
    # detach()/cpu() make this work for CUDA tensors and tensors requiring grad;
    # reshape accepts non-contiguous inputs that .view(-1) would reject.
    y_flat = y.detach().cpu().reshape(-1).numpy()
    classes = np.unique(y_flat)
    class_weights = class_weight.compute_class_weight('balanced', classes=classes, y=y_flat)
    return torch.tensor(class_weights, dtype=torch.float32)

# Class statistics come from the training split only; using the full
# Y_train_tensor would leak validation labels into the loss. With 'balanced'
# weights w_c = N / (2 * n_c), the ratio w1 / w0 = n_neg / n_pos is the value
# the BCEWithLogitsLoss docs recommend for pos_weight.
class_weights = calculate_class_weights(y_train)
pos_weight = (class_weights[1] / class_weights[0]).to(device)

_bce_with_logits = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

def criterion(outputs, masks, eps=1e-6):
    """Weighted BCE for a model whose forward pass already applies a sigmoid.

    The UNet in this section returns probabilities. Passing them straight to
    BCEWithLogitsLoss would apply a second sigmoid, squashing every prediction
    into [0.5, 0.73] and weakening the gradients. ``torch.logit`` undoes the
    sigmoid first; ``eps`` clamps probabilities away from 0 and 1 so the logits
    stay finite.
    """
    return _bce_with_logits(torch.logit(outputs, eps=eps), masks)

Model Training

In [None]:
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 50
# Halve the LR when the mean training loss has not improved for 5 epochs.
# NOTE(review): stepping on the validation loss is the more common choice.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    # Distinct loop names: at notebook top level, `for images, masks in ...`
    # would overwrite the global `images`/`masks` dataset arrays that a later
    # cell splits into train/validation sets.
    for batch_images, batch_masks in train_loader:
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device).float()

        optimizer.zero_grad()
        outputs = model(batch_images)
        loss = criterion(outputs, batch_masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    scheduler.step(epoch_loss)

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

Model Validation

In [None]:
model.eval()
val_loss = 0.0

with torch.no_grad():
    # Distinct loop names so the global `images`/`masks` dataset arrays used by
    # a later cell are not overwritten with the last validation batch.
    for batch_images, batch_masks in val_loader:
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device).float()

        val_loss += criterion(model(batch_images), batch_masks).item()

print(f'Validation Loss: {val_loss/len(val_loader):.4f}')

torch.save(model.state_dict(), "unet_weights_with_class_weights3.pth")

Model Evaluation

In [None]:
import numpy as np
import torch
from sklearn.metrics import jaccard_score, f1_score, balanced_accuracy_score, accuracy_score

def evaluate_model(model, val_loader, device, threshold=0.5):
    """Run ``model`` over ``val_loader`` and collect binarised predictions and targets.

    Outputs are compared directly with ``threshold`` (no sigmoid is applied).
    This suits a model whose forward pass already returns probabilities.

    Args:
        model: network producing per-pixel scores; switched to eval mode.
        val_loader: iterable of ``(images, masks)`` batches.
        device: device that the images are moved to.
        threshold: score above which a pixel is predicted positive.

    Returns:
        ``(all_preds, all_targets)``: float32 numpy arrays concatenated along
        the batch axis.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for images, masks in val_loader:
            outputs = model(images.to(device))
            all_preds.append((outputs > threshold).float().cpu().numpy())
            all_targets.append(masks.float().cpu().numpy())

    if not all_preds:
        raise ValueError("val_loader yielded no batches")

    return np.concatenate(all_preds, axis=0), np.concatenate(all_targets, axis=0)

def calculate_metrics(preds, targets):
    """Compute pixel-wise balanced accuracy and plain accuracy.

    Args:
        preds: binary prediction array of any shape.
        targets: ground-truth array with the same number of elements.

    Returns:
        ``(balanced_accuracy, accuracy)`` as computed by scikit-learn.
    """
    y_pred = preds.flatten()
    y_true = targets.flatten()
    return (
        balanced_accuracy_score(y_true, y_pred),
        accuracy_score(y_true, y_pred),
    )

# Threshold the validation predictions and summarise them with two pixel-wise scores.
preds, targets = evaluate_model(model, val_loader, device)
balanced_acc, accuracy = calculate_metrics(preds, targets)
print(f'Balanced Accuracy: {balanced_acc:.4f}')
print(f'Accuracy: {accuracy:.4f}')

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def visualize_results(images, preds, targets, num_samples=5):
    """Visualize input images, predicted masks, and ground truth masks.

    Draws one row per sample with three panels: input, prediction, ground truth.

    Args:
        images, preds, targets: indexable collections of CPU arrays/tensors;
            each item is squeezed to 2-D before plotting.
        num_samples: requested number of rows. It is clamped to the shortest
            input, so asking for more samples than exist no longer raises
            IndexError.
    """
    n_rows = min(num_samples, len(images), len(preds), len(targets))
    if n_rows <= 0:
        return

    # squeeze=False keeps `axes` 2-D even when only one row is drawn.
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
    columns = (('Input Image', images), ('Predicted Mask', preds), ('Ground Truth Mask', targets))
    for row in range(n_rows):
        for ax, (title, source) in zip(axes[row], columns):
            ax.imshow(source[row].squeeze(), cmap='gray')
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

# Show the first five validation samples: input, thresholded prediction, ground truth.
visualize_results(X_val, preds, y_val, num_samples=5)

# PyTorch U-Net with Class Weights

In [100]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score
import matplotlib.pyplot as plt
from tqdm import tqdm

In [101]:
class NumpyDataset(Dataset):
    """
    Segmentation dataset backed by in-memory arrays.

    Each item is an ``(image, mask)`` pair of float32 tensors with a leading
    channel axis prepended, i.e. a 2-D sample of shape (H, W) becomes (1, H, W).

    Args:
        images: indexable collection of 2-D images.
        masks: indexable collection of 2-D masks, same length as ``images``.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries.
    """
    def __init__(self, images, masks, transform=None):
        assert len(images) == len(masks), "Images and masks should have the same length."
        self.images, self.masks, self.transform = images, masks, transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Prepend the channel axis: (H, W) -> (1, H, W).
        sample_img = self.images[idx][None, ...]
        sample_mask = self.masks[idx][None, ...]

        if self.transform:
            out = self.transform(image=sample_img, mask=sample_mask)
            sample_img = out["image"]
            sample_mask = out["mask"]

        return tuple(torch.tensor(t, dtype=torch.float32) for t in (sample_img, sample_mask))

In [102]:
class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> LeakyReLU) stages followed by dropout.

    ``padding=1`` keeps the spatial size unchanged; only the channel count goes
    from ``in_channels`` to ``out_channels``.
    """
    def __init__(self, in_channels, out_channels, dropout_rate=0.1):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(dropout_rate)
        )

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    """U-Net with skip connections that returns raw per-pixel logits.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map.
        features: encoder stage widths, shallow to deep (any sequence). The
            bottleneck uses ``2 * features[-1]`` channels. Each stage halves the
            spatial size; when pooling floors an odd size, the upsampled map is
            resized to match its skip connection.
    """
    # Tuple default: a list default is a single shared mutable object across calls.
    def __init__(self, in_channels=1, out_channels=1, features=(64, 128, 256, 512)):
        super(UNet, self).__init__()
        features = list(features)
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        # Built once instead of on every forward pass; it has no parameters or
        # buffers, so existing state_dict checkpoints still load unchanged.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        for feature in features:
            self.encoder.append(DoubleConv(in_channels, feature))
            in_channels = feature

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Decoder alternates [upsample, DoubleConv] per stage, deepest stage first.
        for feature in reversed(features):
            self.decoder.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.decoder.append(DoubleConv(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skips = []
        for stage in self.encoder:
            x = stage(x)
            skips.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Each (upsample, conv) pair consumes the matching skip, deepest first.
        for i in range(0, len(self.decoder), 2):
            x = self.decoder[i](x)
            skip = skips[-(i // 2 + 1)]

            # Only the spatial size can differ (pooling floors odd sizes).
            if x.shape[2:] != skip.shape[2:]:
                x = nn.functional.interpolate(x, size=skip.shape[2:], mode="bilinear", align_corners=True)

            x = torch.cat((skip, x), dim=1)
            x = self.decoder[i + 1](x)

        return self.final_conv(x)

In [103]:
# Function to calculate pos_weight based on pixel balance
def calculate_pos_weight(Y_train):
    """Return the background/foreground pixel ratio for ``BCEWithLogitsLoss(pos_weight=...)``.

    Args:
        Y_train: numpy array or torch tensor of binary masks (1 = crater, 0 = background).

    Returns:
        ``num_background / (num_crater + 1e-6)`` as a plain Python float. The
        conversion matters because a numpy float64 scalar passed to
        ``torch.tensor([pos_weight])`` yields a float64 tensor, which does not
        match float32 model outputs; a Python float yields float32.
    """
    num_crater_pixels = float((Y_train == 1).sum())
    num_background_pixels = float((Y_train == 0).sum())
    return num_background_pixels / (num_crater_pixels + 1e-6)

In [104]:
# Training function
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5, device="cpu"):
    """Train ``model`` and report the mean train/val loss after every epoch.

    Each epoch makes one optimisation pass over ``train_loader``, then one
    gradient-free evaluation pass over ``val_loader``.

    Args:
        model: network to train; moved to ``device`` in place.
        train_loader: iterable of ``(images, masks)`` batches used for updates.
        val_loader: iterable of ``(images, masks)`` batches used only for evaluation.
        criterion: callable ``criterion(outputs, masks)`` returning a scalar loss tensor.
        optimizer: optimizer already bound to ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        device: device that the model and batches are moved to.

    Returns:
        A list of ``(train_loss, val_loss)`` tuples, one per epoch, each the mean
        loss per batch. A loader with no batches reports ``0.0`` instead of
        raising ``ZeroDivisionError``. The model is left in eval mode.
    """
    model.to(device)
    history = []

    for epoch in range(num_epochs):
        model.train()
        train_loss, train_batches = 0.0, 0

        for images, masks in tqdm(train_loader, leave=True):
            images, masks = images.to(device), masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_batches += 1

        model.eval()
        val_loss, val_batches = 0.0, 0
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                val_loss += criterion(outputs, masks).item()
                val_batches += 1

        # Count batches instead of len(loader): this guards against empty
        # loaders and also supports iterables without __len__.
        mean_train = train_loss / max(train_batches, 1)
        mean_val = val_loss / max(val_batches, 1)
        history.append((mean_train, mean_val))
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {mean_train:.4f}, "
              f"Val Loss: {mean_val:.4f}")

    return history

In [105]:
# Function to calculate balanced accuracy
def check_accuracy(loader, model, device="cpu", threshold=0.5):
    """Print and return the pixel-wise balanced accuracy of ``model`` on ``loader``.

    Model outputs are treated as logits: a pixel is predicted positive when
    ``sigmoid(output) > threshold``. Balanced accuracy is the mean of
    sensitivity TP / (TP + FN) and specificity TN / (TN + FP). A rate whose
    denominator is zero counts as 0.

    Args:
        loader: iterable of ``(images, masks)`` batches, where masks use 1 for
            positive pixels and 0 for negative pixels.
        model: segmentation network returning one logit per pixel.
        device: device that the batches are moved to.
        threshold: probability cut-off for a positive prediction.

    Returns:
        The balanced accuracy as a float in [0, 1]. The model's original
        train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    tp = tn = fp = fn = 0

    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            preds = (torch.sigmoid(model(images)) > threshold).float()

            tp += ((preds == 1) & (masks == 1)).sum().item()
            tn += ((preds == 0) & (masks == 0)).sum().item()
            fp += ((preds == 1) & (masks == 0)).sum().item()
            fn += ((preds == 0) & (masks == 1)).sum().item()

    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    balanced_accuracy = (sensitivity + specificity) / 2
    print(f"Balanced Accuracy: {balanced_accuracy * 100:.2f}%")

    # Restore the caller's mode: forcing train() here would silently re-enable
    # dropout/batch-norm updates on a model that was loaded for inference.
    model.train(was_training)
    return balanced_accuracy

In [106]:
# Visualization function
def save_predictions_as_imgs(loader, model, num_images=6, device="cpu"):
    """Display up to ``num_images`` (image, prediction, ground truth) triplets.

    Despite the name, figures are shown with ``plt.show()`` and are not written
    to disk. Predictions are ``sigmoid(model(images)) > 0.5``, so the model is
    expected to return logits.

    Args:
        loader: iterable of ``(images, masks)`` batches.
        model: segmentation network returning one logit per pixel.
        num_images: maximum number of samples to display.
        device: device that the batches are moved to.

    The model's original train/eval mode is restored afterwards, even if
    plotting raises.
    """
    was_training = model.training
    model.eval()
    images_shown = 0

    try:
        for images, masks in loader:
            # Stop before running the model on a batch that will not be shown.
            if images_shown >= num_images:
                break
            images, masks = images.to(device), masks.to(device)
            with torch.no_grad():
                preds = (torch.sigmoid(model(images)) > 0.5).float()

            for i in range(images.shape[0]):
                if images_shown >= num_images:
                    break
                fig, axes = plt.subplots(1, 3, figsize=(12, 4))
                panels = (("Image", images[i]), ("Prediction", preds[i]), ("Ground Truth", masks[i]))
                for ax, (title, tensor) in zip(axes, panels):
                    ax.set_title(title)
                    ax.imshow(tensor.cpu().squeeze(), cmap="gray")
                plt.show()
                plt.close(fig)  # avoid accumulating open figures
                images_shown += 1
    finally:
        model.train(was_training)

In [None]:

# Rebuild the full dataset here rather than reusing the `images`/`masks`
# globals from the first cell. Top-level loops such as
# `for images, masks in val_loader:` in the class-weights section rebind those
# names to a single batch, which would make this split operate on one batch
# instead of the whole training set.
images = np.load("Xtrain2_b.npy").reshape(-1, 48, 48)
masks = np.load("Ytrain2_b.npy").reshape(-1, 48, 48)

# Seed weight initialisation and DataLoader shuffling for reproducibility.
torch.manual_seed(42)

X_train, X_val, Y_train, Y_val = train_test_split(images, masks, test_size=0.2, random_state=42)

train_dataset = NumpyDataset(X_train, Y_train)
val_dataset = NumpyDataset(X_val, Y_val)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

model = UNet(in_channels=1, out_channels=1)

# Up-weight the rare crater pixels. An explicit float32 dtype keeps pos_weight
# consistent with the model's logits even if the ratio is a numpy float64.
pos_weight = calculate_pos_weight(Y_train)
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], dtype=torch.float32, device="cpu"))

optimizer = optim.Adam(model.parameters(), lr=1e-4)

train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5, device="cpu")

check_accuracy(val_loader, model, device="cpu")

save_predictions_as_imgs(val_loader, model, num_images=6, device="cpu")

torch.save(model.state_dict(), "u-net_pytorch_wc.pth")


Model Verification

In [None]:
def load_model(model_path, device="cpu"):
    """Build a single-output ``UNet`` and load a saved ``state_dict`` into it.

    Args:
        model_path: path to a ``.pth`` file written with ``torch.save(model.state_dict(), ...)``.
        device: device to map the stored tensors to and to move the model onto.

    Returns:
        The model on ``device``, in eval mode.
    """
    model = UNet(in_channels=1, out_channels=1)
    # weights_only=True (PyTorch >= 1.13) limits unpickling to tensors and plain
    # containers, so a tampered checkpoint cannot execute code on load.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model

# Reload the checkpoint saved by the training cell and re-check it on the same validation split.
device = "cpu"
model_path = "u-net_pytorch_wc.pth"
model = load_model(model_path, device=device)

check_accuracy(val_loader, model, device=device)

# Generating Ytest2_b (test-set predictions)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

def generate_predictions(X_teste, model, device="cpu", batch_size=None, image_size=48, threshold=0.5):
    """Predict binary masks for flattened square test images.

    Args:
        X_teste: array whose first axis indexes samples; each sample holds
            ``image_size * image_size`` values.
        model: segmentation network returning per-pixel logits.
        device: device used for inference.
        batch_size: images per forward pass. ``None`` (the default) runs all
            images in one pass, as before; smaller values bound memory use.
        image_size: side length of each square image.
        threshold: probability cut-off applied after the sigmoid.

    Returns:
        A float32 numpy array of shape ``(num_images, image_size * image_size)``
        holding 0/1 values.

    The model runs in eval mode, so dropout and batch-norm behave the same
    regardless of how the caller left it. Its original mode is restored
    afterwards.
    """
    num_images = X_teste.shape[0]
    images = torch.tensor(
        np.asarray(X_teste).reshape(num_images, 1, image_size, image_size), dtype=torch.float32
    )
    step = batch_size if batch_size else max(num_images, 1)

    was_training = model.training
    model.eval()
    chunks = []
    try:
        with torch.no_grad():
            for start in range(0, num_images, step):
                batch = images[start:start + step].to(device)
                probs = torch.sigmoid(model(batch))
                chunks.append((probs > threshold).float().reshape(batch.shape[0], -1).cpu())
    finally:
        model.train(was_training)

    if not chunks:
        return np.zeros((0, image_size * image_size), dtype=np.float32)
    return torch.cat(chunks).numpy()

def visualize_predictions_grid(predictions, num_images=16, cols=4):
    """Show the first ``num_images`` flattened 48x48 masks in a grid with ``cols`` columns.

    ``num_images`` is clamped to ``len(predictions)``, so a short prediction
    array does not raise IndexError. ``squeeze=False`` keeps ``axes`` an array
    even for a 1x1 grid, where ``plt.subplots`` would otherwise return a
    single Axes with no ``flatten``.
    """
    num_images = min(num_images, len(predictions))
    if num_images <= 0:
        return
    rows = -(-num_images // cols)  # ceiling division
    fig, axes = plt.subplots(rows, cols, figsize=(12, 12), squeeze=False)
    axes = axes.flatten()

    for i, ax in enumerate(axes):
        if i < num_images:
            ax.imshow(predictions[i].reshape(48, 48), cmap="gray")
            ax.set_title(f"Predict {i+1}")
        # Hide the axes on every panel; unused trailing panels stay blank.
        ax.axis("off")

    plt.tight_layout()
    plt.show()

X_test = np.load("Xtest2_b.npy")
print(X_test.shape)

device = "cpu"
model_path = "u-net_pytorch_wc.pth"
model = load_model(model_path, device=device)

# One flattened 0/1 mask (48*48 values) per test image.
predictions = generate_predictions(X_test, model, device=device)
np.save("Ytest2_b.npy", predictions)

# Round-trip check: reload the saved file and report its shape.
Y_test = np.load("Ytest2_b.npy")
print(Y_test.shape)

visualize_predictions_grid(predictions, num_images=16, cols=4)