In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # prefer GPU when present
import torchvision.transforms.functional as F  # Add this line to import 'F'
import os

# Discriminator

In [None]:
class discriminator(nn.Module):
    """Fully convolutional real/fake discriminator.

    Four 4x4 conv stages (3 -> 64 -> 128 -> 256 -> 512 channels; the first three
    use stride 2, the fourth stride 1) with LeakyReLU(0.2), InstanceNorm on all
    but the first stage, then a final 4x4 conv to a single channel. The output
    is a spatial grid of sigmoid probabilities (no pooling/flattening), one per
    receptive-field patch of the input.
    """

    def __init__(self):
        super().__init__()

        def stage(inChannels, outChannels, stride, normalise):
            # Conv -> [InstanceNorm] -> LeakyReLU, all convs use reflect padding.
            layers = [nn.Conv2d(in_channels=inChannels, out_channels=outChannels,
                                kernel_size=4, stride=stride, padding=1,
                                padding_mode='reflect')]
            if normalise:
                layers.append(nn.InstanceNorm2d(num_features=outChannels))
            layers.append(nn.LeakyReLU(0.2))
            return nn.Sequential(*layers)

        self.block1 = stage(3, 64, stride=2, normalise=False)
        self.block2 = stage(64, 128, stride=2, normalise=True)
        self.block3 = stage(128, 256, stride=2, normalise=True)
        self.block4 = stage(256, 512, stride=1, normalise=True)
        self.block5 = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4,
                                stride=1, padding=1, padding_mode='reflect')

    def forward(self, x):
        """Return per-patch probabilities that ``x`` is real."""
        for layer in (self.block1, self.block2, self.block3, self.block4, self.block5):
            x = layer(x)
        return torch.sigmoid(x)


# Generator

In [None]:
class generator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: a 7x7 stem conv (3 -> 64 channels), two stride-2 3x3 convs
    (64 -> 128 -> 256, each halving height/width), ``numResiduals`` residual
    blocks at 256 channels, two stride-2 transposed convs
    (256 -> 128 -> 64, each doubling height/width) and a final 7x7 conv to 3
    channels squashed into [-1, 1] by tanh.

    Input height/width should be divisible by 4 for the output to have the same
    spatial size as the input.

    Args:
        numResiduals: number of residual blocks in the bottleneck (0 allowed).
    """

    def __init__(self, numResiduals):
        super().__init__()

        # Initial convolutional layer
        self.down1 = nn.Sequential(
            nn.Conv2d(in_channels = 3, out_channels = 64,kernel_size= 7, stride = 1, padding = 3 , padding_mode='reflect'),
            nn.ReLU(),
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(),
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(in_channels = 128, out_channels = 256, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(),
        )

        # Container of residual branches F_i; forward() adds each branch's own
        # identity skip. Kept as an nn.Sequential so parameter names in saved
        # state_dicts ("residuals.residual_block_{i}.*") are unchanged.
        self.residuals = nn.Sequential()
        for i in range(numResiduals):
            self.residuals.add_module(f"residual_block_{i}", nn.Sequential(
                nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size=3, padding=1, padding_mode="reflect"),
                nn.InstanceNorm2d(256),
                nn.ReLU(),
                nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size=3, padding=1, padding_mode="reflect"),
                nn.InstanceNorm2d(256),
            ))

        # Upsampling: output_padding=1 makes each stage exactly double H and W.
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels = 256, out_channels = 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(),
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels = 128, out_channels = 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
        )

        # Projection back to 3 image channels (tanh applied in forward).
        self.up3 = nn.Conv2d(in_channels = 64,out_channels = 3,kernel_size=7,stride=1,padding=3,padding_mode="reflect",)

    def forward(self, x):
        """Translate a batch of images; output values lie in [-1, 1]."""
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        # Every residual block carries its own identity skip: x <- x + F_i(x).
        # (A single skip around the whole stack would make it a plain deep
        # conv stack and, for numResiduals == 0, would double the activations.)
        for residualBlock in self.residuals:
            x = x + residualBlock(x)
        x = self.up1(x)
        x = self.up2(x)
        x = self.up3(x)
        return torch.tanh(x)


In [None]:
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torchvision.transforms.functional as F

from PIL import Image


class dataSetMaps(Dataset):
    """Unpaired aerial-photo / map-tile dataset.

    The two directories are listed (sorted, for a reproducible index -> file
    mapping) and indexed independently: item ``i`` uses
    ``mapImages[i % mapLength]`` and ``aerialImages[i % aerialLength]``. The
    dataset length is that of the larger folder; the smaller one is cycled.

    Each item is ``(mapImage, aerialImage)``. Both images are loaded as RGB,
    given the same random horizontal flip (p=0.5), resized to 256x256,
    converted to tensors and normalised from [0, 1] to [-1, 1].

    NOTE(review): every directory entry is treated as an image file; stray
    non-image files (e.g. ``.DS_Store``) will fail when loaded.

    Args:
        aerialDirectory: directory containing the aerial images.
        mapDirectory: directory containing the map images.

    Raises:
        ValueError: if either directory contains no entries.
    """

    def __init__(self, aerialDirectory, mapDirectory):
        self.aerialDirectory = aerialDirectory
        self.mapDirectory = mapDirectory
        self.transforms = transforms.Compose([
            transforms.Resize(size=(256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        # os.listdir order is arbitrary; sort so runs are reproducible.
        self.aerialImages = sorted(os.listdir(aerialDirectory))
        self.mapImages = sorted(os.listdir(mapDirectory))
        self.aerialLength = len(self.aerialImages)
        self.mapLength = len(self.mapImages)
        if self.aerialLength == 0 or self.mapLength == 0:
            raise ValueError(
                f"dataSetMaps needs images in both folders; found {self.aerialLength} in "
                f"{aerialDirectory!r} and {self.mapLength} in {mapDirectory!r}"
            )
        self.length = max(self.aerialLength, self.mapLength)

    def __len__(self):
        return self.length

    @staticmethod
    def _loadRGB(path):
        """Open ``path`` and return an RGB copy, closing the file handle."""
        with Image.open(path) as image:
            return image.convert("RGB")

    def __getitem__(self, index):
        # Wrap each list separately: indices run up to max(lengths), so the
        # shorter folder must be cycled rather than indexed out of range.
        mapName = self.mapImages[index % self.mapLength]
        aerialName = self.aerialImages[index % self.aerialLength]

        mapImage = self._loadRGB(os.path.join(self.mapDirectory, mapName))
        aerialImage = self._loadRGB(os.path.join(self.aerialDirectory, aerialName))

        # Same flip decision for both images.
        if torch.rand(1) < 0.5:
            aerialImage = F.hflip(aerialImage)
            mapImage = F.hflip(mapImage)

        return self.transforms(mapImage), self.transforms(aerialImage)

In [None]:
import torch
import os

def _saveCheckpoint(mapDiscriminator, aerialDiscriminator, mapGenerator, aerialGenerator,
                    DiscriminatorOptimiser, GeneratorOptimiser, savePath, epoch):
    """Write all model/optimiser state dicts to ``savePath/model_epoch_{epoch}.pth``."""
    os.makedirs(savePath, exist_ok=True)
    torch.save({
        'mapGenerator_state_dict': mapGenerator.state_dict(),
        'aerialGenerator_state_dict': aerialGenerator.state_dict(),
        'mapDiscriminator_state_dict': mapDiscriminator.state_dict(),
        'aerialDiscriminator_state_dict': aerialDiscriminator.state_dict(),
        'generator_optimizer_state_dict': GeneratorOptimiser.state_dict(),
        'discriminator_optimizer_state_dict': DiscriminatorOptimiser.state_dict(),
        'epoch': epoch
    }, os.path.join(savePath, f'model_epoch_{epoch}.pth'))
    print(f'Models saved at epoch {epoch}.')


def trainingFunction(mapDiscriminator, aerialDiscriminator, mapGenerator, aerialGenerator, 
                     DiscriminatorOptimiser, GeneratorOptimiser, l1Loss, meanSquareError, 
                     dataSet, loader, cycleLambda, savePath, epoch, saveInterval=10):
    """Run one epoch of CycleGAN training, checkpointing every ``saveInterval`` epochs.

    Direction convention (fixed by the cycle-consistency terms):
        mapGenerator:    aerial -> map
        aerialGenerator: map    -> aerial

    Per batch: one discriminator update (real -> 1, detached fake -> 0, scored
    with ``meanSquareError``), then one generator update (adversarial terms
    with fake -> 1 plus ``cycleLambda``-weighted L1 cycle reconstruction).

    Args:
        mapDiscriminator, aerialDiscriminator: discriminators for each domain.
        mapGenerator, aerialGenerator: generators (see convention above).
        DiscriminatorOptimiser: optimiser over both discriminators' parameters.
        GeneratorOptimiser: optimiser over both generators' parameters.
        l1Loss: cycle-consistency loss callable.
        meanSquareError: adversarial loss callable.
        dataSet: unused; kept for call-site compatibility.
        loader: yields ``(mapImage, aerialImage)`` batches.
        cycleLambda: weight of the cycle-consistency terms.
        savePath: directory for checkpoints (created if missing).
        epoch: 0-based epoch index; used for the save schedule and filename.
        saveInterval: save after epochs where ``(epoch + 1) % saveInterval == 0``;
            0 or None disables checkpointing.
    """
    # testFunction may have left the networks in eval mode.
    for model in (mapDiscriminator, aerialDiscriminator, mapGenerator, aerialGenerator):
        model.train()

    loop = tqdm(loader, leave=False, desc=f"epoch {epoch}")
    for mapImage, aerialImage in loop:
        mapImage = mapImage.to(device)
        aerialImage = aerialImage.to(device)

        # ---- Discriminators ----
        fakeMap = mapGenerator(aerialImage)      # aerial -> map
        fakeAerial = aerialGenerator(mapImage)   # map -> aerial

        # detach(): the discriminator loss must not backprop into the generators.
        disRealMap = mapDiscriminator(mapImage)
        disFakeMap = mapDiscriminator(fakeMap.detach())
        disMapTotalLoss = (meanSquareError(disRealMap, torch.ones_like(disRealMap))
                           + meanSquareError(disFakeMap, torch.zeros_like(disFakeMap)))

        disRealAerial = aerialDiscriminator(aerialImage)
        disFakeAerial = aerialDiscriminator(fakeAerial.detach())
        disAerialTotalLoss = (meanSquareError(disRealAerial, torch.ones_like(disRealAerial))
                              + meanSquareError(disFakeAerial, torch.zeros_like(disFakeAerial)))

        discriminatorLoss = (disAerialTotalLoss + disMapTotalLoss) / 2

        DiscriminatorOptimiser.zero_grad()
        discriminatorLoss.backward()
        DiscriminatorOptimiser.step()

        # ---- Generators ----
        # Re-score the (non-detached) fakes with the freshly updated discriminators.
        disFakeMap = mapDiscriminator(fakeMap)
        disFakeAerial = aerialDiscriminator(fakeAerial)
        adversarialMapLoss = meanSquareError(disFakeMap, torch.ones_like(disFakeMap))
        adversarialAerialLoss = meanSquareError(disFakeAerial, torch.ones_like(disFakeAerial))

        cycleMap = mapGenerator(fakeAerial)       # map -> aerial -> map
        cycleAerial = aerialGenerator(fakeMap)    # aerial -> map -> aerial
        cycleMapLoss = l1Loss(mapImage, cycleMap)
        cycleAerialLoss = l1Loss(aerialImage, cycleAerial)

        generatorLoss = (adversarialAerialLoss + adversarialMapLoss
                         + cycleLambda * (cycleMapLoss + cycleAerialLoss))

        GeneratorOptimiser.zero_grad()
        generatorLoss.backward()
        GeneratorOptimiser.step()

        loop.set_postfix(D=discriminatorLoss.item(), G=generatorLoss.item())

    # Guarded so a 0/None interval disables saving instead of dividing by zero.
    if saveInterval and saveInterval > 0 and (epoch + 1) % saveInterval == 0:
        _saveCheckpoint(mapDiscriminator, aerialDiscriminator, mapGenerator, aerialGenerator,
                        DiscriminatorOptimiser, GeneratorOptimiser, savePath, epoch)



In [None]:
# One discriminator per domain and two 9-residual-block generators.
# Direction convention used by trainingFunction's cycle terms:
#   mapGenerator:    aerial -> map
#   aerialGenerator: map    -> aerial
mapDiscriminator = discriminator().to(device)
aerialDiscriminator = discriminator().to(device)

mapGenerator = generator(9).to(device)
aerialGenerator = generator(9).to(device)

LEARNING_RATE = 1e-5
ADAM_BETAS = (0.5, 0.99)

# Each optimiser must own the parameters of the networks it updates; the
# generator optimiser previously received the discriminators' parameters,
# so the generators were never trained.
DiscriminatorOptimiser = optim.Adam(
    list(mapDiscriminator.parameters()) + list(aerialDiscriminator.parameters()),
    lr=LEARNING_RATE,
    betas=ADAM_BETAS,
)
GeneratorOptimiser = optim.Adam(
    list(mapGenerator.parameters()) + list(aerialGenerator.parameters()),
    lr=LEARNING_RATE,
    betas=ADAM_BETAS,
)

l1Loss = nn.L1Loss()
meanSquareError = nn.MSELoss()

dataSet = dataSetMaps('mapscycle/dataset/aerialTrain', 'mapscycle/dataset/mapTrain')
loader = DataLoader(dataSet, batch_size=4, shuffle=True, num_workers=2, pin_memory=True)

# Weight of the L1 cycle-consistency terms relative to the adversarial terms.
cycleLambda = 10


In [None]:
epochs = 1
for i in range(epochs):
    # Keyword arguments: positionally, the previous call passed epoch=10 and
    # saveInterval=i (swapped).
    trainingFunction(mapDiscriminator, aerialDiscriminator, mapGenerator, aerialGenerator,
                     DiscriminatorOptimiser, GeneratorOptimiser, l1Loss, meanSquareError,
                     dataSet, loader, cycleLambda,
                     savePath='modelVersions', epoch=i, saveInterval=10)


In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
import torch
from sklearn.metrics import jaccard_score

# Validation split. It has its own name so the training `dataSet` is not
# overwritten, and uses a relative map path (the previous "/mapscycle/..." was absolute).
# No shuffling, so evaluation (and the images shown) are reproducible.
testDataSet = dataSetMaps('mapscycle/dataset/aerialVal', 'mapscycle/dataset/mapVal')
testLoader = DataLoader(testDataSet, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)

def testFunction(mapGenerator, aerialGenerator, testLoader, maxImages=5):
    """Evaluate both generators on ``testLoader`` and report mean IoU.

    Uses the same direction convention as training:
    ``aerialGenerator`` maps map -> aerial and ``mapGenerator`` maps aerial -> map.

    Args:
        mapGenerator: aerial -> map generator.
        aerialGenerator: map -> aerial generator.
        testLoader: yields ``(mapImage, aerialImage)`` batches.
        maxImages: number of leading batches to visualise with ``showImages``
            (first sample of each batch); 0 disables plotting.

    Returns:
        ``(avgIouMapToAerial, avgIouAerialToMap)``, or ``(None, None)`` if the
        loader yields no batches. The generators' previous train/eval modes
        are restored afterwards.
    """
    previousModes = (mapGenerator.training, aerialGenerator.training)
    mapGenerator.eval()
    aerialGenerator.eval()

    iouValuesMapToAerial = []
    iouValuesAerialToMap = []

    try:
        with torch.no_grad():
            for index, (mapImage, aerialImage) in enumerate(testLoader):
                mapImage = mapImage.to(device)
                aerialImage = aerialImage.to(device)

                generatedAerial = aerialGenerator(mapImage)   # map -> aerial
                generatedMap = mapGenerator(aerialImage)      # aerial -> map

                iouValuesMapToAerial.append(calculateIou(generatedAerial, aerialImage))
                iouValuesAerialToMap.append(calculateIou(generatedMap, mapImage))

                if index < maxImages:
                    # showImages expects a single image per tensor.
                    showImages(mapImage[:1], aerialImage[:1],
                               generatedAerial[:1], generatedMap[:1])
    finally:
        mapGenerator.train(previousModes[0])
        aerialGenerator.train(previousModes[1])

    if not iouValuesMapToAerial:
        print("testLoader yielded no batches; nothing to evaluate.")
        return None, None

    avgIouMapToAerial = sum(iouValuesMapToAerial) / len(iouValuesMapToAerial)
    avgIouAerialToMap = sum(iouValuesAerialToMap) / len(iouValuesAerialToMap)

    print(f"Average IoU (Map to Aerial): {avgIouMapToAerial}")
    print(f"Average IoU (Aerial to Map): {avgIouAerialToMap}")
    return avgIouMapToAerial, avgIouAerialToMap

def showImages(mapImage, aerialImage, generatedAerial, generatedMap):
    """Plot the original map/aerial images next to their translations.

    Each argument is assumed to be a single image tensor of shape
    (1, 3, H, W) or (3, H, W) with values in [-1, 1]. It is de-normalised to
    [0, 1] for display. Only dim 0 is squeezed, so pass one image at a time.
    """
    panels = [
        ((0, 0), mapImage, 'Original Map'),
        ((0, 1), aerialImage, 'Original Aerial'),
        ((0, 2), generatedAerial, 'Generated Aerial'),
        ((1, 0), generatedMap, 'Generated Map'),
    ]

    fig, axes = plt.subplots(2, 3, figsize=(12, 8))
    # Hide every axis first so the two unused panels are blank, not empty frames.
    for ax in axes.flat:
        ax.axis('off')

    for (row, col), tensor, title in panels:
        # detach(): imshow cannot convert tensors that require grad.
        image = tensor.detach().cpu().squeeze(0).permute(1, 2, 0) * 0.5 + 0.5
        axes[row, col].imshow(image.clamp(0, 1).numpy())
        axes[row, col].set_title(title)

    plt.show()
    plt.close(fig)

def calculateIou(generated, real, threshold=0.5):
    """Binary IoU (Jaccard index) between two thresholded tensors.

    Both tensors are binarised with ``> threshold`` and compared over all
    elements (batch, channels and pixels pooled together). The computation
    stays on the tensors' device.

    NOTE(review): the images in this notebook are in [-1, 1], so the default
    0.5 corresponds to 0.75 on a [0, 1] scale. Confirm that this is the intended mask.

    Args:
        generated: predicted tensor.
        real: reference tensor with the same number of elements.
        threshold: binarisation threshold (strictly greater counts as 1).

    Returns:
        IoU as a float in [0, 1]. Returns 0.0 when both masks are empty, the same
        value sklearn's ``jaccard_score`` returns by default, but without the warning.

    Raises:
        ValueError: if the tensors have different numbers of elements.
    """
    generatedMask = (generated > threshold).reshape(-1)
    realMask = (real > threshold).reshape(-1)
    if generatedMask.numel() != realMask.numel():
        raise ValueError(
            f"calculateIou: size mismatch ({generatedMask.numel()} vs {realMask.numel()} elements)"
        )

    intersection = (generatedMask & realMask).sum().item()
    union = (generatedMask | realMask).sum().item()
    return intersection / union if union else 0.0

testFunction(
    mapGenerator,
    aerialGenerator,
    testLoader,
    maxImages=5,
)
