In [None]:
import matplotlib.pyplot as plt 
import torch 
import torch.nn as nn
from torchsummary import summary
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import ToTensor
from dataloader import WHU_bldg
from Model import Unet
from engine import train_one_epoch

In [None]:
# HYPERPARAMETERS

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('device:', device)

# Seed torch's global RNG (all devices) so DataLoader shuffling and default
# layer initialisation are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

BATCH_SIZE = 8
EPOCHS = 30
lr = 0.001  # learning rate handed to train_one_epoch
# Side length of the target masks. 324 is exactly what an unpadded
# (valid-convolution) U-Net produces from a 512x512 input; presumably that is
# how Model.Unet is built — confirm if the architecture changes.
out_size = 324

In [None]:
def unnormalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Invert a per-channel ``Normalize``: return ``tensor * std + mean``.

    Args:
        tensor: channel-first image tensor, ``(C, H, W)`` or ``(B, C, H, W)``;
            the statistics are reshaped to ``(C, 1, 1)`` and broadcast over the
            trailing channel/spatial axes.
        mean: per-channel means used when normalizing (defaults match the
            ImageNet statistics used in ``transform_rgb``).
        std: per-channel standard deviations used when normalizing.

    Returns:
        A new tensor on the same device as ``tensor``; the input is not
        modified.
    """
    # Build the statistics on the input's device so CUDA tensors work too
    # (previously they were always created on the CPU).
    mean_t = torch.tensor(mean, device=tensor.device).reshape(-1, 1, 1)
    std_t = torch.tensor(std, device=tensor.device).reshape(-1, 1, 1)
    return tensor * std_t + mean_t

In [None]:
# NOTE(review): `dataset_dir` is not used anywhere below — the dataset cell
# passes parent_dir="dataset" — and this path names the DeepGlobe road dataset
# while the loader is WHU_bldg. Confirm which location is intended.
dataset_dir = "/dataset/deepglobe-2018-road-extraction"

# ImageNet channel statistics; `unnormalize` must invert with the same values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Input images: PIL -> uint8 tensor -> float32 in [0, 1] -> 512x512 -> normalized.
transform_rgb = transforms.Compose([
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float32),  # uint8 [0, 255] -> float32 [0.0, 1.0]
    # Tensor Resize's `antialias` default changed to True in torchvision 0.17
    # (earlier releases used False, with a warning); pin it so results do not
    # depend on the installed version.
    transforms.Resize((512, 512), antialias=True),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Masks: resized straight to the model's output size with NEAREST so label
# values are never interpolated, then cast to float32 for the loss
# (ConvertImageDtype also rescales uint8 masks from [0, 255] to [0.0, 1.0]).
# NOTE(review): if Unet uses unpadded convolutions (512 -> 324), its output
# covers only the central 324x324 region of the input, whereas this resize
# squeezes the whole mask into 324x324. Verify alignment against Model.py.
transform_grey = transforms.Compose([
    transforms.PILToTensor(),
    transforms.Resize((out_size, out_size), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ConvertImageDtype(torch.float32),
])


In [None]:
# Build the three WHU_bldg splits; they differ only in the `set` argument.
_split_kwargs = dict(parent_dir="dataset", transform_rgb=transform_rgb, transform_grey=transform_grey)
train, val, test = (WHU_bldg(set=split_name, **_split_kwargs) for split_name in ("train", "val", "test"))

## Display a sample

Show one training image (un-normalized for viewing) next to its ground-truth mask.

In [None]:
SAMPLE_IDX = 10  # arbitrary training example to visualise

sample_image, sample_mask = train[SAMPLE_IDX]

# Undo transform_rgb's Normalize for display. Floating-point round-off can leave
# values a hair outside [0, 1]; imshow would clip them anyway but emit a warning,
# so clip explicitly. permute: CHW -> HWC as imshow expects.
image = unnormalize(sample_image).clamp(0.0, 1.0).permute(1, 2, 0).numpy()
mask = sample_mask.squeeze().numpy()  # drop singleton channel dim -> 2-D array
image.shape, mask.shape

In [None]:
# Side-by-side view: input image (left) and its ground-truth mask (right).
fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))

ax_img.imshow(image)
ax_img.set_title('Image')

ax_mask.imshow(mask, cmap='gray')
ax_mask.set_title('gt_mask')

for ax in (ax_img, ax_mask):
    ax.axis('off')

plt.show()

In [None]:
# Shuffle only the training split. Validation order has no effect on learning,
# and a fixed order keeps per-batch validation results (including the smaller
# final batch) identical from run to run.
train_loader = DataLoader(dataset=train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val, batch_size=BATCH_SIZE, shuffle=False)
# batch_size left at DataLoader's default of 1.
test_loader = DataLoader(dataset=test, shuffle=False)

len(train_loader), len(val_loader)  # number of batches per split

In [None]:
# Instantiate the U-Net, then move its parameters to the selected device.
model = Unet()
model = model.to(device)

# Per-layer summary for one 3x512x512 input (the size transform_rgb produces).
summary(model, input_size=(3, 512, 512))


In [None]:
# Launch training through engine.train_one_epoch.
# NOTE(review): despite its name it is given `epochs=EPOCHS`, so it presumably
# runs the whole schedule rather than one epoch — confirm in engine.py.
fit_kwargs = {
    'epochs': EPOCHS,
    'lr': lr,
    'scheduler': 'exponential',
    'out_dir': 'weights',
    'device': device,
}
metrics = train_one_epoch(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    **fit_kwargs,
)