In [33]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from PIL import Image
import os

class ImageDataset(Dataset):
    """Dataset that serves one image from each of two folders per index.

    The two folders are listed independently, so the pairs are positional
    (sorted path order), not matched by file name.  When the folders differ
    in size, the shorter one wraps around via modulo indexing.

    Args:
        root_X: Directory holding the domain-X images.
        root_Y: Directory holding the domain-Y images.
        transforms: Optional callable applied separately to each loaded image.

    Raises:
        ValueError: If either directory contains no ``.jpg``/``.png`` file.
    """

    # Extensions are compared case-insensitively so "IMG.JPG" is not skipped.
    _EXTENSIONS = ('.jpg', '.png')

    def __init__(self, root_X, root_Y, transforms=None):
        self.transform = transforms
        self.files_X = self._list_images(root_X)
        self.files_Y = self._list_images(root_Y)

    @classmethod
    def _list_images(cls, root):
        """Return the sorted image paths in `root`, failing fast if there are none."""
        files = sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.lower().endswith(cls._EXTENSIONS)
        )
        if not files:
            # Without this check the modulo in __getitem__ would raise an
            # unhelpful ZeroDivisionError on the first sample.
            raise ValueError(f"No .jpg/.png images found in {root!r}")
        return files

    @staticmethod
    def _load_rgb(path):
        """Open `path`, force 3-channel RGB, and release the file handle."""
        with Image.open(path) as img:
            # convert() loads the pixel data and returns a new image, so the
            # result stays valid after the file is closed.  Forcing RGB keeps
            # grayscale/palette/RGBA files compatible with 3-channel pipelines.
            return img.convert('RGB')

    def __getitem__(self, index):
        image_X = self._load_rgb(self.files_X[index % len(self.files_X)])
        image_Y = self._load_rgb(self.files_Y[index % len(self.files_Y)])

        if self.transform:
            image_X = self.transform(image_X)
            image_Y = self.transform(image_Y)

        return image_X, image_Y

    def __len__(self):
        # The larger folder defines one pass over the data.
        return max(len(self.files_X), len(self.files_Y))

In [34]:
# Preprocessing applied to every image of both domains: resize to 256x256 with
# bicubic resampling, convert to a tensor, and scale each channel to [-1, 1].
# The previous RandomCrop((256, 256)) was dropped: cropping a 256x256 image to
# 256x256 always returns the full image.
transform = transforms.Compose([
    # Use torchvision's InterpolationMode enum instead of the PIL integer constant.
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize to [-1, 1]
])


In [4]:
!unzip /content/images.zip

In [35]:
# Folders for the two image domains: X = aerial tiles, Y = their masks.
root_X = '/content/Semantic segmentation dataset/Tile_1/images'
root_Y = '/content/Semantic segmentation dataset/Tile_1/masks'

# Both folders go through the same preprocessing pipeline.
dataset = ImageDataset(root_X, root_Y, transforms=transform)

# One (X, Y) pair per step, visited in a new random order every epoch.
dataloader = DataLoader(dataset=dataset, batch_size=1, shuffle=True)

In [36]:
# Sanity check: walk the loader once and count the samples it yields.
sum(len(batch_x) for batch_x, _batch_y in dataloader)

9

In [48]:
import torch
import torch.nn as nn

# Define the Residual Block for the Generator
class ResidualBlock(nn.Module):
    """Two 3x3 reflect-padded convolutions with a skip connection.

    The branch is conv -> InstanceNorm -> ReLU -> conv -> InstanceNorm and its
    output is added to the input, so channel count and spatial size are kept.

    Args:
        in_features: Number of channels entering and leaving the block.
    """

    def __init__(self, in_features):
        super().__init__()

        def conv3x3():
            # Channel-preserving 3x3 convolution; padding=1 keeps H and W.
            return nn.Conv2d(
                in_features,
                in_features,
                kernel_size=3,
                padding=1,
                padding_mode='reflect',
            )

        self.block = nn.Sequential(
            conv3x3(),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        """Return the input plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual

# Define the Generator Model
class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image network.

    Layout: a 7x7 stem to 64 channels, two stride-2 convolutions that double
    the channels each time (64 -> 128 -> 256), `num_residual_blocks` residual
    blocks, two stride-2 transposed convolutions that halve the channels back
    to 64, and a 7x7 output convolution followed by Tanh.

    Args:
        input_channels: Channels of the input image.
        output_channels: Channels of the generated image.
        num_residual_blocks: Number of ResidualBlock modules in the middle.
    """

    def __init__(self, input_channels, output_channels, num_residual_blocks=9):
        super().__init__()

        width = 64

        # Stem
        layers = [
            nn.Conv2d(input_channels, width, kernel_size=7, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(width),
            nn.ReLU(inplace=True),
        ]

        # Encoder: each stage halves the resolution and doubles the channels.
        for _ in range(2):
            layers.extend([
                nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(width * 2),
                nn.ReLU(inplace=True),
            ])
            width *= 2

        # Bottleneck
        layers.extend(ResidualBlock(width) for _ in range(num_residual_blocks))

        # Decoder: each stage doubles the resolution and halves the channels.
        for _ in range(2):
            layers.extend([
                nn.ConvTranspose2d(width, width // 2, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(width // 2),
                nn.ReLU(inplace=True),
            ])
            width //= 2

        # Head: Tanh bounds the output to [-1, 1].
        layers.extend([
            nn.Conv2d(64, output_channels, kernel_size=7, padding=3, padding_mode='reflect'),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run `x` through the full layer stack."""
        return self.model(x)

# Define the Discriminator Model
class Discriminator(nn.Module):
    """Convolutional critic that outputs a grid of scores, one per patch.

    Three stride-2 4x4 convolutions (64, 128, 256 channels) are followed by a
    stride-1 4x4 convolution to 512 channels and a final 4x4 convolution to a
    single channel.  Every stage except the first uses InstanceNorm, and all
    hidden stages use LeakyReLU(0.2).  No activation is applied to the output.

    Args:
        input_channels: Channels of the images being judged.
    """

    def __init__(self, input_channels):
        super().__init__()

        # First stage has no normalisation.
        layers = [
            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # (in_channels, out_channels, stride) of the normalised stages.
        stages = ((64, 128, 2), (128, 256, 2), (256, 512, 1))
        for in_ch, out_ch, stride in stages:
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=1))
            layers.append(nn.InstanceNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Patch output: one raw score per spatial location.
        layers.append(nn.Conv2d(512, 1, kernel_size=4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map for the batch `x`."""
        return self.model(x)

In [49]:
# Losses
class GANLoss(nn.Module):
    """Adversarial objective implemented as mean squared error.

    Despite the parameter names, `forward` simply compares its first argument
    with its second using MSE.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.MSELoss()

    def forward(self, fake, real):
        """Return the mean squared error between `fake` and `real`."""
        return self.loss(fake, real)

class CycleConsistencyLoss(nn.Module):
    """Mean absolute error between a reconstructed image and its original."""

    def __init__(self):
        super().__init__()
        self.loss = nn.L1Loss()

    def forward(self, reconstructed, original):
        """Return the L1 distance between `reconstructed` and `original`."""
        l1_distance = self.loss(reconstructed, original)
        return l1_distance

In [50]:
def weight_init(m):
    """Initialise one module in place; intended for use with `Module.apply`.

    Modules whose class name contains "Conv" get weights drawn from
    N(0, 0.02).  Modules whose class name contains "BatchNorm" get weights
    drawn from N(1, 0.02) and a zero bias.  Any other module is left as is.
    """
    layer_type = type(m).__name__

    if "Conv" in layer_type:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if "BatchNorm" in layer_type:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
# CycleGAN training script.
# Relies on names defined in earlier cells: Generator, Discriminator, GANLoss,
# CycleConsistencyLoss, weight_init and `dataloader` (which yields (x, y) batches).

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize models: one generator per translation direction and one
# discriminator per domain, all operating on 3-channel images.
G_XtoY = Generator(3, 3).to(device)
G_YtoX = Generator(3, 3).to(device)
D_X = Discriminator(3).to(device)
D_Y = Discriminator(3).to(device)

# Module.apply calls weight_init on every submodule of each network.
G_XtoY.apply(weight_init)
G_YtoX.apply(weight_init)
D_X.apply(weight_init)
D_Y.apply(weight_init)

# Optimizers: both generators share one Adam optimizer because they are
# updated from a single combined loss; each discriminator has its own.
optimizer_G = torch.optim.Adam(list(G_XtoY.parameters()) + list(G_YtoX.parameters()), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_X = torch.optim.Adam(D_X.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_Y = torch.optim.Adam(D_Y.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Losses: MSE for the adversarial terms, L1 for the cycle terms.
adversarial_loss = GANLoss().to(device)
cycle_consistency_loss = CycleConsistencyLoss().to(device)

# Training Loop
# NOTE(review): no checkpoint is written anywhere in this cell, so the trained
# weights exist only in kernel memory.
num_epochs = 2000
for epoch in range(num_epochs):
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)

        # ---- Generator step: update G_XtoY and G_YtoX together ----
        optimizer_G.zero_grad()

        # Identity loss
        # (omitted for simplicity, you can add this to improve performance)

        # GAN Loss: each generator is rewarded when the discriminator of its
        # target domain scores the translated image as real (target = 1).
        fake_y = G_XtoY(x)
        pred_fake_y = D_Y(fake_y)
        loss_GAN_XtoY = adversarial_loss(pred_fake_y, torch.ones_like(pred_fake_y))

        fake_x = G_YtoX(y)
        pred_fake_x = D_X(fake_x)
        loss_GAN_YtoX = adversarial_loss(pred_fake_x, torch.ones_like(pred_fake_x))

        # Cycle consistency loss: translating there and back should
        # reproduce the original image (X -> Y -> X and Y -> X -> Y).
        reconstructed_x = G_YtoX(fake_y)
        loss_cycle_X = cycle_consistency_loss(reconstructed_x, x)

        reconstructed_y = G_XtoY(fake_x)
        loss_cycle_Y = cycle_consistency_loss(reconstructed_y, y)

        # Total loss: the cycle terms are weighted 10x relative to the GAN terms.
        # backward() also deposits gradients in D_X and D_Y; they are discarded
        # by the zero_grad() calls at the start of each discriminator step.
        total_loss_G = loss_GAN_XtoY + loss_GAN_YtoX + loss_cycle_X * 10.0 + loss_cycle_Y * 10.0
        total_loss_G.backward()
        optimizer_G.step()

        # ---- Discriminator D_X step: real x -> 1, translated x -> 0 ----
        optimizer_D_X.zero_grad()

        pred_real_x = D_X(x)
        loss_D_real = adversarial_loss(pred_real_x, torch.ones_like(pred_real_x))

        # detach() cuts the graph so this loss does not reach the generators.
        pred_fake_x = D_X(fake_x.detach())
        loss_D_fake = adversarial_loss(pred_fake_x, torch.zeros_like(pred_fake_x))

        # Total Discriminator X loss (mean of the real and fake terms)
        total_loss_D_X = (loss_D_real + loss_D_fake) * 0.5
        total_loss_D_X.backward()
        optimizer_D_X.step()

        # ---- Discriminator D_Y step: real y -> 1, translated y -> 0 ----
        optimizer_D_Y.zero_grad()

        pred_real_y = D_Y(y)
        loss_D_real_y = adversarial_loss(pred_real_y, torch.ones_like(pred_real_y))

        # detach() cuts the graph so this loss does not reach the generators.
        pred_fake_y = D_Y(fake_y.detach())
        loss_D_fake_y = adversarial_loss(pred_fake_y, torch.zeros_like(pred_fake_y))

        # Total Discriminator Y loss (mean of the real and fake terms)
        total_loss_D_Y = (loss_D_real_y + loss_D_fake_y) * 0.5
        total_loss_D_Y.backward()
        optimizer_D_Y.step()

        # NOTE(review): this prints once per batch; the captured output was
        # truncated to its last 5000 lines. Consider logging once per epoch.
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss_G: {total_loss_G.item()}, Loss_D_X: {total_loss_D_X.item()}, Loss_D_Y: {total_loss_D_Y.item()}")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch [810/2000], Loss_G: 3.1854496002197266, Loss_D_X: 0.02279803343117237, Loss_D_Y: 0.04254557564854622
Epoch [810/2000], Loss_G: 3.0247390270233154, Loss_D_X: 0.017173990607261658, Loss_D_Y: 0.09722720086574554
Epoch [810/2000], Loss_G: 3.558319568634033, Loss_D_X: 0.022475363686680794, Loss_D_Y: 0.09877438843250275
Epoch [810/2000], Loss_G: 2.7616770267486572, Loss_D_X: 0.022703317925333977, Loss_D_Y: 0.06339096277952194
Epoch [810/2000], Loss_G: 2.938981771469116, Loss_D_X: 0.014597882516682148, Loss_D_Y: 0.06427156925201416
Epoch [810/2000], Loss_G: 3.496487855911255, Loss_D_X: 0.051931291818618774, Loss_D_Y: 0.11308109760284424
Epoch [811/2000], Loss_G: 3.1593685150146484, Loss_D_X: 0.027964439243078232, Loss_D_Y: 0.08386798202991486
Epoch [811/2000], Loss_G: 3.079641103744507, Loss_D_X: 0.027787018567323685, Loss_D_Y: 0.03735942393541336
Epoch [811/2000], Loss_G: 3.0233447551727295, Loss_D_X: 0.017302358523011208

In [1]:
import matplotlib.pyplot as plt

def show_transformed_images(test_loader, generator_XtoY, generator_YtoX, device, num_images=5):
    """Plot real and translated images side by side, one batch per row.

    Each row shows, left to right: the first X image of the batch, its
    translation to Y, the first Y image of the batch, and its translation
    to X.

    Args:
        test_loader: Iterable yielding (x, y) batches of image tensors.
        generator_XtoY: Model mapping X images to Y images.
        generator_YtoX: Model mapping Y images to X images.
        device: Device the batches are moved to before inference.
        num_images: Number of rows (batches) to display.
    """
    # squeeze=False keeps `axes` two-dimensional even when num_images == 1;
    # with the default, matplotlib returns a 1-D array and axes[row, col] fails.
    fig, axes = plt.subplots(
        nrows=num_images, ncols=4, figsize=(15, num_images * 5), squeeze=False
    )
    titles = ['Original X', 'Fake Y', 'Original Y', 'Fake X']

    # Hide every frame up front so rows left unused by a loader with fewer
    # than num_images batches stay blank instead of showing empty axes.
    for ax in axes.flat:
        ax.axis('off')

    for row, (x, y) in enumerate(test_loader):
        if row >= num_images:
            break

        # Inference only: no autograd graph is built, so no detach() is needed.
        with torch.no_grad():
            x, y = x.to(device), y.to(device)
            fake_y = generator_XtoY(x)
            fake_x = generator_YtoX(y)

        # Only the first sample of each batch is shown.
        images = [x[0], fake_y[0], y[0], fake_x[0]]

        for col, (img, title) in enumerate(zip(images, titles)):
            # Map [-1, 1] back to [0, 1] and clamp so float rounding cannot
            # push values outside the range imshow accepts; then CHW -> HWC.
            img = (img.cpu() * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).numpy()
            ax = axes[row, col]
            ax.imshow(img)
            ax.set_title(title)

    plt.tight_layout()
    plt.show()


In [2]:
# Pick the device again (this cell may run in a different session than the
# training cell) and make sure both generators live on it before plotting.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for generator in (G_XtoY, G_YtoX):
    generator.to(device)

# Show five rows of real images next to their translations in both directions.
show_transformed_images(dataloader, G_XtoY, G_YtoX, device, num_images=5)