# In[ ]:
import os
from PIL import Image
import torch
from torchvision import transforms
from tqdm import tqdm 

# Dataset folders: files in `a_dir` become the "pre" images, files in `b_dir`
# the "post" images and files in `label_dir` the masks. Raw strings keep the
# Windows backslashes readable (the path values are unchanged).
a_dir = r"C:\DewangData\VSCode\Endsem\EndsemDL\TUE-CD\TUE\a"
b_dir = r"C:\DewangData\VSCode\Endsem\EndsemDL\TUE-CD\TUE\b"
label_dir = r"C:\DewangData\VSCode\Endsem\EndsemDL\TUE-CD\TUE\label"


def _list_image_files(directory):
    """Return the sorted full paths of the image files directly inside *directory*.

    Only regular files whose extension Pillow has a registered format for are
    kept. Sub-folders and stray system files (e.g. ``Thumbs.db``,
    ``desktop.ini``, ``.DS_Store``) are therefore skipped. Otherwise they would
    crash ``Image.open`` later or shift the position-based pairing of the three
    folders.
    """
    image_exts = Image.registered_extensions()
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if os.path.splitext(name)[1].lower() in image_exts
        and os.path.isfile(os.path.join(directory, name))
    )


pre_paths = _list_image_files(a_dir)
post_paths = _list_image_files(b_dir)
mask_paths = _list_image_files(label_dir)

# Input images: resize to 256x256 and convert to float tensors via ToTensor.
img_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Masks: resize with nearest-neighbour so label values stay discrete. Resize's
# default bilinear filter blends label and background pixels along every edge.
# The `> 0` binarisation applied to the mask tensor afterwards would turn those
# blended pixels into foreground and dilate each mask.
mask_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

def _stem(path):
    """Return the file name of *path* without its directory or extension."""
    return os.path.splitext(os.path.basename(path))[0]


# The three folders are paired purely by sorted position, so verify the pairing
# before loading anything. zip() would silently drop the surplus entries of a
# longer list. A missing or extra file in one folder would shift every later
# sample onto the wrong partner.
if not (len(pre_paths) == len(post_paths) == len(mask_paths)):
    raise ValueError(
        f"Folder sizes differ: {len(pre_paths)} pre images, "
        f"{len(post_paths)} post images, {len(mask_paths)} masks"
    )
if not pre_paths:
    # torch.stack() below would otherwise fail with a less helpful message.
    raise ValueError("No input images found; nothing to preprocess")

# NOTE(review): assumes a sample's pre image, post image and mask share the same
# base file name (extensions may differ). Relax this check if the dataset names
# its masks differently.
misaligned = [
    triplet
    for triplet in zip(pre_paths, post_paths, mask_paths)
    if len({_stem(path) for path in triplet}) != 1
]
if misaligned:
    raise ValueError(
        f"{len(misaligned)} pre/post/mask triplets have mismatched file names, "
        f"first one: {misaligned[0]}"
    )

all_pre = []
all_post = []
all_masks = []

for pre_path, post_path, mask_path in tqdm(zip(pre_paths, post_paths, mask_paths), total=len(pre_paths)):
    # `with` releases each file handle once the converted copy exists;
    # convert() returns a new image, so nothing refers back to the file.
    with Image.open(pre_path) as img:
        pre_img = img.convert("RGB")
    with Image.open(post_path) as img:
        post_img = img.convert("RGB")
    with Image.open(mask_path) as img:
        mask_img = img.convert("L")  # single-channel grayscale mask

    pre_tensor = img_transform(pre_img)
    post_tensor = img_transform(post_img)
    mask_tensor = mask_transform(mask_img)

    # Binarise: every non-zero mask value becomes 1.0, zero stays 0.0.
    mask_tensor = (mask_tensor > 0).float()

    all_pre.append(pre_tensor)
    all_post.append(post_tensor)
    all_masks.append(mask_tensor)

# Stack each stream into a single tensor of shape (num_samples, *per-sample shape).
all_pre = torch.stack(all_pre)
all_post = torch.stack(all_post)
all_masks = torch.stack(all_masks)

# Sanity check: report the shape of every stacked tensor.
shape_report = (
    ("Pre shape:", all_pre),
    ("Post shape:", all_post),
    ("Mask shape:", all_masks),
)
for caption, tensor in shape_report:
    print(caption, tensor.shape)

# Persist each stacked tensor to its own file (relative to the current
# working directory).
output_files = (
    ("pre_images.pt", all_pre),
    ("post_images.pt", all_post),
    ("masks.pt", all_masks),
)
for filename, tensor in output_files:
    torch.save(tensor, filename)

print("✅ Dataset preprocessing complete and saved as .pt files")
