In [65]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os

In [66]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.optim as optim
import PIL

from sklearn.model_selection import train_test_split
from tqdm import tqdm, trange

In [67]:
data_path = "../data/input/0_datasets/L1_7/train_B/class_0"
results_path = "../data/results/L1_7"

# Data

In [68]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.data = os.listdir(data_path)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img = PIL.Image.open(os.path.join(self.data_path, self.data[idx]))
        if self.transform:
            img = self.transform(img)
        return img

In [69]:
# Data loading and transformation, like flipping and adding noise
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor(),
])


# Create a dataset for loading images
image_dataset = Dataset(data_path, transform=transform)

In [70]:
from sklearn.model_selection import train_test_split

# Assuming image_dataset is already defined and loaded with images
# Split the dataset into training and test sets (75% train, 25% test)
train_data, test_data = train_test_split(image_dataset, test_size=0.25)

# Further split the training data into training and validation sets (75% train, 25% validation)
train_data, val_data = train_test_split(train_data, test_size=0.25)

# Print the sizes of the datasets
print(f"Training data size: {len(train_data)}")
print(f"Validation data size: {len(val_data)}")
print(f"Test data size: {len(test_data)}")

# Create a DataLoader for each dataset
batch_size = 5
train_loader = DataLoader(train_data, batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size, shuffle=False)

print(f"Number of batches in train_loader: {len(train_loader)}")
print(f"Number of batches in val_loader: {len(val_loader)}")
print(f"Number of batches in test_loader: {len(test_loader)}")

Training data size: 30
Validation data size: 11
Test data size: 14
Number of batches in train_loader: 6
Number of batches in val_loader: 3
Number of batches in test_loader: 3


In [71]:
image_size = 3 * 1024 * 2048  # Total number of elements in the image

# Model

## Pix2Pix

In [72]:
import torch
import torch.nn as nn

class UNetGenerator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(UNetGenerator, self).__init__()
        
        def conv_block(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
        
        def upconv_block(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
                nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
        
        self.encoder1 = conv_block(in_channels, 64)
        self.encoder2 = conv_block(64, 128)
        self.encoder3 = conv_block(128, 256)
        self.encoder4 = conv_block(256, 512)
        self.encoder5 = conv_block(512, 1024)
        
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.upconv4 = upconv_block(1024, 512)
        self.upconv3 = upconv_block(1024, 256)
        self.upconv2 = upconv_block(512, 128)
        self.upconv1 = upconv_block(256, 64)  # Corrected input channels
        
        self.final_conv = nn.Conv2d(128, out_channels, kernel_size=1)  # Corrected input channels
        
    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool(enc1))
        enc3 = self.encoder3(self.pool(enc2))
        enc4 = self.encoder4(self.pool(enc3))
        enc5 = self.encoder5(self.pool(enc4))
        
        dec4 = self.upconv4(enc5)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        
        return self.final_conv(dec1)

In [73]:
class PatchGANDiscriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(PatchGANDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels * 2, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, x, y):
        x = torch.cat([x, y], dim=1)
        return self.model(x)


In [74]:
# Test the fixed code
X = torch.rand(1, 3, 1024, 2048)
generator = UNetGenerator()
discriminator = PatchGANDiscriminator()

fake_image = generator(X)
discriminator_output = discriminator(fake_image, X)
print(f"Generator output shape: {fake_image.shape}")
print(f"Discriminator output shape: {discriminator_output.shape}")

Generator output shape: torch.Size([1, 3, 1024, 2048])
Discriminator output shape: torch.Size([1, 1, 126, 254])


## GAN

In [75]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True)
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


In [76]:
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(1024, 1, kernel_size=4, stride=1, padding=0),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1*29*61, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.fc(x)
        return x

In [77]:
# Initialize the models
generator = Generator()
discriminator = Discriminator()

In [78]:
X = torch.randn(1, 3, 1024, 2048)
out = generator(X)
print(f"Generator shape: {out.shape}")
out = discriminator(out)
print(f"Discriminator shape: {out.shape}")
print(f"Discriminator output: {out}")   
loss_fn = nn.BCELoss()
loss = loss_fn(out, torch.zeros(1, 1))
print(f"Loss: {loss}")

Generator shape: torch.Size([1, 3, 1024, 2048])
Discriminator shape: torch.Size([1, 1])
Discriminator output: tensor([[0.5016]], grad_fn=<SigmoidBackward0>)
Loss: 0.6964022517204285


# Model

## Training

In [79]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import trange
import matplotlib.pyplot as plt

# Assuming UNetGenerator and PatchGANDiscriminator are already defined

def train(generator, discriminator, train_loader, val_loader, num_epochs=50, lr=0.0002):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)
    discriminator.to(device)

    criterion_GAN = nn.BCEWithLogitsLoss()
    criterion_L1 = nn.L1Loss()

    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

    train_losses_G = []
    train_losses_D = []
    val_losses_G = []
    val_losses_D = []

    for epoch in trange(num_epochs, desc="Epochs"):
        generator.train()
        discriminator.train()
        
        epoch_loss_G = 0
        epoch_loss_D = 0

        for real_images in train_loader:
            real_images = real_images.to(device)
            fake_images = generator(real_images)

            # Train Discriminator
            optimizer_D.zero_grad()
            real_labels = torch.ones(real_images.size(0), 1, 126, 254).to(device)
            fake_labels = torch.zeros(real_images.size(0), 1, 126, 254).to(device)

            real_output = discriminator(real_images, real_images)
            fake_output = discriminator(fake_images.detach(), real_images)

            loss_D_real = criterion_GAN(real_output, real_labels)
            loss_D_fake = criterion_GAN(fake_output, fake_labels)
            loss_D = (loss_D_real + loss_D_fake) / 2
            loss_D.backward()
            optimizer_D.step()

            # Train Generator
            optimizer_G.zero_grad()
            fake_output = discriminator(fake_images, real_images)
            loss_G_GAN = criterion_GAN(fake_output, real_labels)
            loss_G_L1 = criterion_L1(fake_images, real_images)
            loss_G = loss_G_GAN + 100 * loss_G_L1
            loss_G.backward()
            optimizer_G.step()

            epoch_loss_G += loss_G.item()
            epoch_loss_D += loss_D.item()

        train_losses_G.append(epoch_loss_G / len(train_loader))
        train_losses_D.append(epoch_loss_D / len(train_loader))

        val_loss_G, val_loss_D = validate(generator, discriminator, val_loader, criterion_GAN, criterion_L1, device)
        val_losses_G.append(val_loss_G)
        val_losses_D.append(val_loss_D)

    plot_losses(train_losses_G, train_losses_D, val_losses_G, val_losses_D)


In [80]:

def validate(generator, discriminator, val_loader, criterion_GAN, criterion_L1, device):
    generator.eval()
    discriminator.eval()

    val_loss_G = 0
    val_loss_D = 0

    with torch.no_grad():
        for real_images in val_loader:
            real_images = real_images.to(device)
            fake_images = generator(real_images)

            real_labels = torch.ones(real_images.size(0), 1, 126, 254).to(device)
            fake_labels = torch.zeros(real_images.size(0), 1, 126, 254).to(device)

            real_output = discriminator(real_images, real_images)
            fake_output = discriminator(fake_images, real_images)

            loss_D_real = criterion_GAN(real_output, real_labels)
            loss_D_fake = criterion_GAN(fake_output, fake_labels)
            loss_D = (loss_D_real + loss_D_fake) / 2

            fake_output = discriminator(fake_images, real_images)
            loss_G_GAN = criterion_GAN(fake_output, real_labels)
            loss_G_L1 = criterion_L1(fake_images, real_images)
            loss_G = loss_G_GAN + 100 * loss_G_L1

            val_loss_G += loss_G.item()
            val_loss_D += loss_D.item()

    return val_loss_G / len(val_loader), val_loss_D / len(val_loader)


In [81]:

def test(generator, discriminator, test_loader, criterion_GAN, criterion_L1, device):
    generator.eval()
    discriminator.eval()

    test_loss_G = 0
    test_loss_D = 0

    with torch.no_grad():
        for real_images in test_loader:
            real_images = real_images.to(device)
            fake_images = generator(real_images)

            real_labels = torch.ones(real_images.size(0), 1, 126, 254).to(device)
            fake_labels = torch.zeros(real_images.size(0), 1, 126, 254).to(device)

            real_output = discriminator(real_images, real_images)
            fake_output = discriminator(fake_images, real_images)

            loss_D_real = criterion_GAN(real_output, real_labels)
            loss_D_fake = criterion_GAN(fake_output, fake_labels)
            loss_D = (loss_D_real + loss_D_fake) / 2

            fake_output = discriminator(fake_images, real_images)
            loss_G_GAN = criterion_GAN(fake_output, real_labels)
            loss_G_L1 = criterion_L1(fake_images, real_images)
            loss_G = loss_G_GAN + 100 * loss_G_L1

            test_loss_G += loss_G.item()
            test_loss_D += loss_D.item()

    return test_loss_G / len(test_loader), test_loss_D / len(test_loader)


In [82]:

def plot_losses(train_losses_G, train_losses_D, val_losses_G, val_losses_D):
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses_G, label='Train Generator Loss')
    plt.plot(train_losses_D, label='Train Discriminator Loss')
    plt.plot(val_losses_G, label='Validation Generator Loss')
    plt.plot(val_losses_D, label='Validation Discriminator Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()


In [83]:
# Example usage
# Assuming train_loader, val_loader, and test_loader are defined
generator = UNetGenerator()
discriminator = PatchGANDiscriminator()

train(generator, discriminator, train_loader, val_loader, num_epochs=50, lr=0.0002)
test_loss_G, test_loss_D = test(generator, discriminator, test_loader, nn.BCEWithLogitsLoss(), nn.L1Loss(), torch.device("cuda" if torch.cuda.is_available() else "cpu"))
print(f"Test Generator Loss: {test_loss_G}, Test Discriminator Loss: {test_loss_D}")

Epochs:   0%|          | 0/50 [00:30<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 320.00 MiB. GPU 0 has a total capacity of 8.00 GiB of which 0 bytes is free. Of the allocated memory 14.00 GiB is allocated by PyTorch, and 259.10 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [198]:
# Assuming train_loader and val_loader are already defined
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Initialize the generator and discriminator
generator = UNetGenerator().to(device)
discriminator = PatchGANDiscriminator().to(device)

# Loss and optimizers
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training parameters
num_epochs = 100


Device: cuda


In [None]:
# Train the GAN
train_losses, val_losses = train(generator, discriminator, train_loader, val_loader, num_epochs, criterion, optimizer_g, optimizer_d, device)


Epoch Loop:   1%|          | 1/100 [04:51<8:00:51, 291.43s/epoch]

Epoch [1/100], Train Loss: 7.8772909839948015, Val Loss: 0.9880637923876444


Epoch Loop:   2%|▏         | 2/100 [07:24<5:42:48, 209.88s/epoch]

Epoch [2/100], Train Loss: 4.923818312585354, Val Loss: 7.303421815236409


Epoch Loop:   3%|▎         | 3/100 [10:00<5:00:00, 185.57s/epoch]

Epoch [3/100], Train Loss: 5.260098261758685, Val Loss: 9.70291010538737


Epoch Loop:   4%|▍         | 4/100 [12:34<4:36:50, 173.03s/epoch]

Epoch [4/100], Train Loss: 5.417339256498963, Val Loss: 7.1079104741414385


Epoch Loop:   5%|▌         | 5/100 [15:07<4:22:31, 165.81s/epoch]

Epoch [5/100], Train Loss: 5.911184047038357, Val Loss: 4.077771345774333


Epoch Loop:   6%|▌         | 6/100 [17:40<4:12:50, 161.39s/epoch]

Epoch [6/100], Train Loss: 6.28053996572271, Val Loss: 2.5712745984395347


Epoch Loop:   7%|▋         | 7/100 [20:13<4:05:46, 158.57s/epoch]

Epoch [7/100], Train Loss: 6.604667161901792, Val Loss: 0.5344867606957754


Epoch Loop:   8%|▊         | 8/100 [22:45<4:00:09, 156.62s/epoch]

Epoch [8/100], Train Loss: 7.017241929890588, Val Loss: 0.08026164521773656


Epoch Loop:   9%|▉         | 9/100 [25:18<3:55:35, 155.33s/epoch]

Epoch [9/100], Train Loss: 7.172932996802653, Val Loss: 0.0009890616638585925


Epoch Loop:  10%|█         | 10/100 [27:50<3:51:38, 154.43s/epoch]

Epoch [10/100], Train Loss: 7.328508939205979, Val Loss: 0.0003529550352444251


Epoch Loop:  11%|█         | 11/100 [30:23<3:48:21, 153.95s/epoch]

Epoch [11/100], Train Loss: 7.346533896207499, Val Loss: 0.0009027215031286081


Epoch Loop:  12%|█▏        | 12/100 [32:56<3:45:31, 153.76s/epoch]

Epoch [12/100], Train Loss: 7.4212182134700315, Val Loss: 0.013073699859281382


Epoch Loop:  13%|█▎        | 13/100 [35:29<3:42:40, 153.57s/epoch]

Epoch [13/100], Train Loss: 7.672696263063699, Val Loss: 0.07588412861029308


Epoch Loop:  14%|█▍        | 14/100 [38:06<3:41:17, 154.39s/epoch]

Epoch [14/100], Train Loss: 7.858853703248315, Val Loss: 0.018446484580636024




In [None]:
plot_losses(train_losses, val_losses)