In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [3]:

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        return x

In [4]:
class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv_block = ConvBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv_block(x)
        p = self.pool(x)
        return x, p

In [5]:
class ConvLSTM(nn.Module):
    def __init__(self, in_channels, hidden_channels, kernel_size, padding):
        super(ConvLSTM, self).__init__()
        self.conv = nn.Conv2d(in_channels + hidden_channels, 4 * hidden_channels, kernel_size, padding=padding)

    def forward(self, x, hidden):
        combined_input = torch.cat([x, hidden[0]], dim=1)
        gates = self.conv(combined_input)
        ingate, forgetgate, cellgate, outgate = torch.split(gates, gates.size(1) // 4, dim=1)
        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)
        cell = forgetgate * hidden[1] + ingate * cellgate
        hidden = outgate * torch.tanh(cell)
        return hidden, cell

In [6]:
class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv_lstm = ConvLSTM(in_channels=out_channels, hidden_channels=out_channels, kernel_size=3, padding=1)

    def forward(self, x, skip, hidden=None, cell=None):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        if hidden is not None and cell is not None:
            x, _ = self.conv_lstm(x, (hidden, cell))
        return x


In [7]:
class AttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=2, stride=2, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        psi = F.interpolate(psi, size=x.size()[2:], mode='bilinear', align_corners=True)
        return x * psi

In [8]:
class ASPPBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ASPPBlock, self).__init__()
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        self.conv3x3_1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=6, dilation=6, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        self.conv3x3_2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=12, dilation=12, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        self.conv3x3_3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=18, dilation=18, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        self.output_conv = nn.Sequential(
            nn.Conv2d(out_channels * 4, out_channels, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x1 = self.conv1x1(x)
        x2 = self.conv3x3_1(x)
        x3 = self.conv3x3_2(x)
        x4 = self.conv3x3_3(x)
        x = torch.cat((x1, x2, x3, x4), dim=1)
        x = self.output_conv(x)
        return x

In [9]:
class SiameseUNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SiameseUNet, self).__init__()
        self.encoder1_1 = EncoderBlock(in_channels, 32)  
        self.encoder1_2 = EncoderBlock(32, 64)          
        self.encoder1_3 = EncoderBlock(64, 128)        
        self.encoder1_4 = EncoderBlock(128, 256)        

        self.encoder2_1 = EncoderBlock(in_channels, 32)  
        self.encoder2_2 = EncoderBlock(32, 64)          
        self.encoder2_3 = EncoderBlock(64, 128)         
        self.encoder2_4 = EncoderBlock(128, 256)        

        self.aspp = ASPPBlock(256 * 2, 512)             

        self.attn1 = AttentionBlock(512, 256 * 2, 256)  
        self.attn2 = AttentionBlock(256, 128 * 2, 128)  
        self.attn3 = AttentionBlock(128, 64 * 2, 64)    
        self.attn4 = AttentionBlock(64, 32 * 2, 32)     

        self.decoder1 = DecoderBlock(512 + 512, 256)    
        self.decoder2 = DecoderBlock(256 + 256, 128)    
        self.decoder3 = DecoderBlock(128 + 128, 64)     
        self.decoder4 = DecoderBlock(64 + 64, 32)       

        self.final_conv = nn.Conv2d(32, out_channels, kernel_size=1)  # Assuming output channels remain 1

    def forward(self, x1, x2):
        # Encoder Path
        s1_1, p1_1 = self.encoder1_1(x1)
        s2_1, p2_1 = self.encoder1_2(p1_1)
        s3_1, p3_1 = self.encoder1_3(p2_1)
        s4_1, p4_1 = self.encoder1_4(p3_1)

        s1_2, p1_2 = self.encoder2_1(x2)
        s2_2, p2_2 = self.encoder2_2(p1_2)
        s3_2, p3_2 = self.encoder2_3(p2_2)
        s4_2, p4_2 = self.encoder2_4(p3_2)

        # ASPP
        concatenated = torch.cat((p4_1, p4_2), dim=1)
        b1 = self.aspp(concatenated)

        # Decoder Path with Attention
        attn_s4 = self.attn1(b1, torch.cat((s4_1, s4_2), dim=1))
        d1 = self.decoder1(b1, attn_s4)

        attn_s3 = self.attn2(d1, torch.cat((s3_1, s3_2), dim=1))
        d2 = self.decoder2(d1, attn_s3)

        attn_s2 = self.attn3(d2, torch.cat((s2_1, s2_2), dim=1))
        d3 = self.decoder3(d2, attn_s2)

        attn_s1 = self.attn4(d3, torch.cat((s1_1, s1_2), dim=1))
        d4 = self.decoder4(d3, attn_s1)

        outputs = torch.sigmoid(self.final_conv(d4))
        return outputs

# Define the model
model = SiameseUNet(in_channels=3, out_channels=1)

In [10]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class CustomWHUDataset(Dataset):
    def __init__(self, base_dir, image_list_file, transform=None):
        self.base_dir = base_dir
        self.image_list_file = image_list_file
        self.transform = transform

        # Read the file containing image names
        with open(image_list_file, 'r') as f:
            self.image_names = [line.strip() for line in f.readlines()]

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        # Get the image name from the list
        image_name = self.image_names[idx]

        # Construct the full paths for time1, time2, and label images
        time1_path = os.path.join(self.base_dir, 'A', f"{image_name}")
        time2_path = os.path.join(self.base_dir, 'B', f"{image_name}")
        label_path = os.path.join(self.base_dir, 'label', f"{image_name}")

        # Open the images
        time1_image = Image.open(time1_path).convert('RGB')
        time2_image = Image.open(time2_path).convert('RGB')
        label_image = Image.open(label_path).convert('L')

        # Apply transformations if provided
        if self.transform:
            time1_image = self.transform(time1_image)
            time2_image = self.transform(time2_image)
            label_image = self.transform(label_image)

        return (time1_image, time2_image), label_image

# Define transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Define base directory and list files
base_dir = '/kaggle/input/whu-rs-change-dataset'
train_list_file = os.path.join(base_dir, 'list', 'train.txt')
val_list_file = os.path.join(base_dir, 'list', 'val.txt')
test_list_file = os.path.join(base_dir, 'list', 'test.txt')

# Create Dataset instances
train_dataset = CustomWHUDataset(base_dir, train_list_file, transform=transform)
val_dataset = CustomWHUDataset(base_dir, val_list_file, transform=transform)
test_dataset = CustomWHUDataset(base_dir, test_list_file, transform=transform)

# Create DataLoader instances
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [11]:
import torch.optim as optim
# Initialize the  loss function, and optimizer

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Define device (GPU/CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

SiameseUNet(
  (encoder1_1): EncoderBlock(
    (conv_block): ConvBlock(
      (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (encoder1_2): EncoderBlock(
    (conv_block): ConvBlock(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, c

In [12]:
def save_checkpoint(state, filename='model_checkpoint.pth.tar'):
    print('=>saving')
    torch.save(state, filename)

# Function to load the model checkpoint
def load_checkpoint(filename='model_checkpoint.pth.tar', model=None, optimizer=None):
    if os.path.isfile(filename):
        print(f"Loading checkpoint '{filename}'")
        checkpoint = torch.load(filename)
        if model is not None:
            model.load_state_dict(checkpoint['state_dict'])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])
        return checkpoint
    else:
        print(f"No checkpoint found at '{filename}'")
        return None

In [13]:
def train(model, train_loader, val_loader, criterion, optimizer, num_epochs, patience, model_save_path):
    best_loss = float('inf')
    patience_counter = 0
    start_epoch = 0

    # Load checkpoint if it exists
    checkpoint = load_checkpoint(model_save_path, model, optimizer)
    if checkpoint:
        start_epoch = checkpoint.get('epoch', 0) + 1
        best_loss = checkpoint.get('best_loss', float('inf'))

    for epoch in range(start_epoch, num_epochs):
        model.train()
        train_losses = []

        for (time1_image, time2_image), target in train_loader:
            time1_image, time2_image, target = time1_image.to(device), time2_image.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(time1_image, time2_image)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        val_loss = evaluate(model, val_loader, criterion)
        print(f'Epoch {epoch+1}, Train Loss: {np.mean(train_losses):.4f}, Val Loss: {val_loss:.4f}')

        if val_loss < best_loss:
            best_loss = val_loss
            patience_counter = 0
            save_checkpoint({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_loss': best_loss
            }, model_save_path)
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print("Early stopping")
            break


In [14]:
# Function to evaluate the model
def evaluate(model, val_loader, criterion):
    model.eval()
    val_losses = []
    with torch.no_grad():
        for (time1, time2), target in val_loader:
            time1, time2, target = time1.to(device), time2.to(device), target.to(device)
            output = model(time1, time2)
            loss = criterion(output, target)
            val_losses.append(loss.item())
    return np.mean(val_losses)


In [15]:
# Function to count the number of parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

# Function to count the number of trainable parameters
def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Print the number of parameters
total_params = count_parameters(model)
trainable_params = count_trainable_parameters(model)
print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")

Total parameters: 21804365
Trainable parameters: 21804365


In [16]:
# Model path
model_save_path = '/kaggle/working/siamese_unet_best_model.pth'


In [17]:
train(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, patience=3, model_save_path=model_save_path)

No checkpoint found at '/kaggle/working/siamese_unet_best_model.pth'
Epoch 1, Train Loss: 0.5735, Val Loss: 0.4955
=>saving
Epoch 2, Train Loss: 0.4452, Val Loss: 0.4247
=>saving
Epoch 3, Train Loss: 0.4132, Val Loss: 0.4194
=>saving
Epoch 4, Train Loss: 0.4014, Val Loss: 0.3931
=>saving
Epoch 5, Train Loss: 0.3887, Val Loss: 0.3820
=>saving
Epoch 6, Train Loss: 0.3811, Val Loss: 0.4240
Epoch 7, Train Loss: 0.3764, Val Loss: 0.3635
=>saving
Epoch 8, Train Loss: 0.3646, Val Loss: 0.3572
=>saving
Epoch 9, Train Loss: 0.3597, Val Loss: 0.4892
Epoch 10, Train Loss: 0.3549, Val Loss: 0.3536
=>saving


In [18]:
# Load the best model
load_checkpoint(model_save_path, model, optimizer)

# Evaluate on the validation set
val_loss = evaluate(model, val_loader, criterion)
print("Validation Loss:", val_loss)


Loading checkpoint '/kaggle/working/siamese_unet_best_model.pth'


  checkpoint = torch.load(filename)


Validation Loss: 0.35364921887715656


In [19]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, jaccard_score

def calculate_metrics(y_true, y_pred, threshold=0.5):
    y_pred = (y_pred > threshold).float()
    
    y_true_np = y_true.cpu().numpy().flatten()
    y_pred_np = y_pred.cpu().numpy().flatten()
    
    accuracy = accuracy_score(y_true_np, y_pred_np)
    precision = precision_score(y_true_np, y_pred_np, average='binary', zero_division=1)
    recall = recall_score(y_true_np, y_pred_np, average='binary', zero_division=1)
    f1 = f1_score(y_true_np, y_pred_np, average='binary', zero_division=1)
    iou = jaccard_score(y_true_np, y_pred_np, average='binary', zero_division=1)
    
    return accuracy, precision, recall, f1, iou


In [20]:
def test_metric_model(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0.0
    all_true = []
    all_pred = []
    
    with torch.no_grad():
        for (time1, time2), target in test_loader:
            time1, time2, target = time1.to(device), time2.to(device), target.to(device)
            output = model(time1, time2)
            loss = criterion(output, target)
            test_loss += loss.item()

            all_true.append(target)
            all_pred.append(output)
    
    # Concatenate all batches
    all_true = torch.cat(all_true, dim=0)
    all_pred = torch.cat(all_pred, dim=0)
    
    # Calculate metrics
    accuracy, precision, recall, f1, iou = calculate_metrics(all_true, all_pred)
    
    return test_loss / len(test_loader), accuracy, precision, recall, f1, iou

# Assuming model, test_loader, criterion, and device are defined as in your code
test_loss, accuracy, precision, recall, f1, iou = test_metric_model(model, test_loader, criterion, device)
print(f'Test Loss: {test_loss:.4f}')
print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')
print(f'IoU: {iou:.4f}')

Test Loss: 0.3526
Accuracy: 0.9848
Precision: 0.8116
Recall: 0.7964
F1 Score: 0.8039
IoU: 0.6721


In [21]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch.utils.data import DataLoader

def calculate_iou(predictions, labels):
    intersection = np.logical_and(predictions, labels).sum()
    union = np.logical_or(predictions, labels).sum()
    iou = intersection / union if union != 0 else 0
    return iou


load_checkpoint(model_save_path, model, optimizer=None)

# Set the model to evaluation mode
model.eval()

# Initialize lists to store evaluation results
accuracies = []
precisions = []
recalls = []
f1_scores = []
ious=[]

# Ensure no gradients are computed during evaluation
with torch.no_grad():
    # Iterate through the test dataloader
    for (time1_image, time2_image), labels in test_loader:
        # Convert inputs and labels to PyTorch tensors and move to the appropriate device
        time1_image = time1_image.to(device)
        time2_image = time2_image.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(time1_image, time2_image)
        
        # Convert outputs to binary predictions
        predictions = (outputs > 0.5).float()
        
        # Ensure the predictions are binary
        predictions_np = predictions.cpu().detach().numpy().astype(int)
        labels_np = labels.cpu().detach().numpy().astype(int)
        
        # Flatten arrays for sklearn metrics
        predictions_flat = predictions_np.flatten()
        labels_flat = labels_np.flatten()
        
        # Calculate metrics
        accuracy = accuracy_score(labels_flat, predictions_flat)
        precision = precision_score(labels_flat, predictions_flat, zero_division=1)
        recall = recall_score(labels_flat, predictions_flat, zero_division=1)
        f1 = f1_score(labels_flat, predictions_flat, zero_division=1)
        
        iou = calculate_iou(predictions_np, labels_np)
        
        # Append results to lists
        accuracies.append(accuracy)
        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)
        ious.append(iou)

# Calculate average metrics
avg_accuracy = np.mean(accuracies)
avg_precision = np.mean(precisions)
avg_recall = np.mean(recalls)
avg_f1 = np.mean(f1_scores)
avg_iou = np.mean(ious)

print("Average Accuracy:", avg_accuracy)
print("Average Precision:", avg_precision)
print("Average Recall:", avg_recall)
print("Average F1 Score:", avg_f1)
print("Average IoU:", avg_iou)

Loading checkpoint '/kaggle/working/siamese_unet_best_model.pth'


  checkpoint = torch.load(filename)


Average Accuracy: 0.9847684237692093
Average Precision: 0.792368304563508
Average Recall: 0.7846505879993385
Average F1 Score: 0.7796273625231404
Average IoU: 0.65201917684659


In [22]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

# Assuming SiameseUNet and related classes are defined as in your provided code
# Define the model
model = SiameseUNet(in_channels=3, out_channels=1)

# Load the model checkpoint
model_save_path = '/kaggle/working/siamese_unet_best_model.pth'
load_checkpoint(model_save_path, model, optimizer=None)

# Set the model to evaluation mode
model.eval()
model.to(device)

Loading checkpoint '/kaggle/working/siamese_unet_best_model.pth'


  checkpoint = torch.load(filename)


SiameseUNet(
  (encoder1_1): EncoderBlock(
    (conv_block): ConvBlock(
      (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (encoder1_2): EncoderBlock(
    (conv_block): ConvBlock(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, c

In [23]:
import torchvision.transforms as transforms
# Define the transformations (resize and convert to tensor)
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

In [24]:
from PIL import Image


# Load the images
def load_image(image_path, transform):
    image = Image.open(image_path).convert('RGB')
    image = transform(image)
    image = image.unsqueeze(0)  # Add batch dimension
    return image.to(device)

# Example image paths
time1_image_path = '/kaggle/input/whu-rs-change-dataset/A/whu_12544_32256.png'
time2_image_path = '/kaggle/input/whu-rs-change-dataset/B/whu_12544_32256.png'
label_image_path='/kaggle/input/whu-rs-change-dataset/label/whu_12544_32256.png'
# Preprocess the images
time1_image = load_image(time1_image_path, transform)
time2_image = load_image(time2_image_path, transform)
label_image=load_image(label_image_path, transform)


In [25]:
# Make predictions
with torch.no_grad():
    output = model(time1_image, time2_image)

# Convert the output to a binary mask (if required)
prediction = (output > 0.5).float()

# To convert the tensor to a numpy array (for further processing or saving the output image)
prediction_np = prediction.squeeze().cpu().numpy()

# Optionally, save the output prediction as an image
import numpy as np
from PIL import Image

prediction_image = Image.fromarray((prediction_np * 255).astype(np.uint8))
prediction_image.save('/kaggle/working/prediction_image00878.png')
