In [184]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os

In [185]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.optim as optim
import PIL

from sklearn.model_selection import train_test_split
from tqdm import tqdm, trange

In [186]:
# Output location for this experiment; none of the visible cells reads it yet.
results_path = "../data/results/L1_7"
# Directory whose files are enumerated by the image dataset created further down.
data_path = "../data/input/0_datasets/L1_7/train_B/class_0"

# Data

In [187]:
class Dataset(torch.utils.data.Dataset):
    """Unlabelled image dataset backed by the regular files of one directory.

    Args:
        data_path: Directory containing the image files.
        transform: Optional callable applied to each PIL image before it is
            returned (e.g. a ``torchvision.transforms.Compose``).

    Attributes:
        data: Sorted list of file names (relative to ``data_path``).
    """

    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        # os.listdir returns entries in arbitrary order and also lists
        # sub-directories (e.g. Jupyter's ``.ipynb_checkpoints``). Sorting makes
        # the index -> file mapping reproducible; the isfile filter avoids
        # trying to open a directory as an image.
        self.data = sorted(
            name
            for name in os.listdir(data_path)
            if os.path.isfile(os.path.join(data_path, name))
        )

    def __len__(self):
        """Return the number of image files found in ``data_path``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed if requested."""
        # Imported here because a bare ``import PIL`` does not guarantee that
        # the ``PIL.Image`` submodule has been loaded.
        from PIL import Image

        path = os.path.join(self.data_path, self.data[idx])
        # The context manager closes the file handle; convert() returns an
        # in-memory copy, so the image stays usable after the file is closed.
        # Forcing RGB keeps grayscale / RGBA files compatible with the
        # 3-channel networks defined later in this notebook.
        with Image.open(path) as handle:
            img = handle.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img

In [188]:
# Augmentation (random flip + mild colour jitter) followed by tensor conversion.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor(),
    # ToTensor scales pixels to [0, 1], but the Generator below ends in Tanh and
    # therefore emits values in [-1, 1]. Rescale real images to the same range
    # so the discriminator cannot separate real from fake by value range alone.
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])


# Create a dataset for loading images
image_dataset = Dataset(data_path, transform=transform)

In [None]:
from sklearn.model_selection import train_test_split

# Split by *index* instead of handing the dataset object to train_test_split.
# Passing the dataset itself makes scikit-learn index it item by item, which
# loads (and augments, once) every image up front into a plain Python list.
# Subset keeps loading lazy, so the random transforms are re-drawn on each access.
# NOTE: val/test subsets share image_dataset's augmenting transform.
all_indices = list(range(len(image_dataset)))

# 75% train+val / 25% test; fixed random_state makes the split reproducible.
train_idx, test_idx = train_test_split(all_indices, test_size=0.25, random_state=42)

# Further split the remaining indices into 75% train / 25% validation.
train_idx, val_idx = train_test_split(train_idx, test_size=0.25, random_state=42)

train_data = torch.utils.data.Subset(image_dataset, train_idx)
val_data = torch.utils.data.Subset(image_dataset, val_idx)
test_data = torch.utils.data.Subset(image_dataset, test_idx)

# Print the sizes of the datasets
print(f"Training data size: {len(train_data)}")
print(f"Validation data size: {len(val_data)}")
print(f"Test data size: {len(test_data)}")

# Create a DataLoader for each dataset (only the training loader is shuffled)
batch_size = 5
train_loader = DataLoader(train_data, batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size, shuffle=False)

print(f"Number of batches in train_loader: {len(train_loader)}")
print(f"Number of batches in val_loader: {len(val_loader)}")
print(f"Number of batches in test_loader: {len(test_loader)}")

In [177]:
# Element count of one image tensor: channels x height x width.
n_channels, height, width = 3, 1024, 2048
image_size = n_channels * height * width

# Model

In [178]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Convolutional encoder-decoder mapping a 3-channel tensor to a 3-channel tensor.

    The encoder stacks four stride-2 ``Conv2d -> BatchNorm2d -> ReLU`` stages;
    the decoder mirrors them with stride-2 ``ConvTranspose2d`` stages and ends
    in ``Tanh``, so outputs lie in [-1, 1].
    """

    # (in_channels, out_channels) of each stage, in execution order.
    _ENCODER_CHANNELS = ((3, 64), (64, 128), (128, 256), (256, 512))
    _DECODER_CHANNELS = ((512, 256), (256, 128), (128, 64))

    def __init__(self):
        super().__init__()

        # Encoder: every stage is conv -> batch norm -> in-place ReLU.
        down_layers = []
        for c_in, c_out in self._ENCODER_CHANNELS:
            down_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        self.encoder = nn.Sequential(*down_layers)

        # Decoder: transposed-conv stages, then a final projection to 3 channels.
        up_layers = []
        for c_in, c_out in self._DECODER_CHANNELS:
            up_layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        up_layers += [
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Run ``x`` through the encoder and then the decoder."""
        return self.decoder(self.encoder(x))


In [179]:
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Convolutional binary classifier producing one probability per image.

    Five stride-2 convolutions (kernel 4, padding 1) are followed by a
    stride-1 convolution (kernel 4, no padding) down to a single channel; the
    resulting map is flattened and passed through ``Linear -> Sigmoid``.

    Args:
        input_size: ``(height, width)`` of the images the network will receive.
            Defaults to ``(1024, 2048)``, which gives the 29 x 61 feature map
            the linear layer was originally hard-coded for.

    Raises:
        ValueError: If ``input_size`` is too small for the convolution stack
            to produce at least one output position per axis.
    """

    def __init__(self, input_size=(1024, 2048)):
        super().__init__()

        # Derive the flattened feature size instead of hard-coding 1*29*61, so
        # the model also works for other resolutions.
        feat_h, feat_w = (self._conv_stack_output(side) for side in input_size)
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"input_size {tuple(input_size)} is too small for the discriminator's convolution stack"
            )

        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(1024, 1, kernel_size=4, stride=1, padding=0),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(feat_h * feat_w, 1),
            nn.Sigmoid()
        )

    @staticmethod
    def _conv_stack_output(size):
        """Return the length of one spatial axis after ``self.conv``."""
        # Five convolutions with kernel 4, stride 2, padding 1 ...
        for _ in range(5):
            size = (size + 2 * 1 - 4) // 2 + 1
        # ... then one convolution with kernel 4, stride 1, padding 0.
        return size - 4 + 1

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of sigmoid outputs for images ``x``."""
        x = self.conv(x)
        x = self.fc(x)
        return x

In [180]:
# Build both networks with their default constructor arguments
# (generator first, then discriminator).
generator, discriminator = Generator(), Discriminator()

In [183]:
# Smoke test: push one noise tensor shaped like a 3 x 1024 x 2048 image through
# both networks and score the discriminator output against an all-zeros target.
loss_fn = nn.BCELoss()
zeros_target = torch.zeros(1, 1)

X = torch.randn(1, 3, 1024, 2048)
out = generator(X)
print(f"Generator shape: {out.shape}")

out = discriminator(out)
print(f"Discriminator shape: {out.shape}")
print(f"Discriminator output: {out}")

loss = loss_fn(out, zeros_target)
print(f"Loss: {loss}")

Generator shape: torch.Size([1, 3, 1024, 2048])
Discriminator shape: torch.Size([1, 1])
Discriminator output: tensor([[0.5683]], grad_fn=<SigmoidBackward0>)
Loss: 0.8400188684463501


# Model

## Training

In [165]:
def validate(generator, discriminator, val_loader, criterion, device):
    """Return the mean per-batch discriminator loss over ``val_loader``.

    For every batch the discriminator is scored on the real images (target 1)
    and on images generated from fresh random noise (target 0); the two losses
    are summed. Both networks are switched to eval mode and left that way.

    Args:
        generator: Module mapping an image-shaped noise tensor to an image tensor.
        discriminator: Module mapping an image batch to ``(batch, 1)`` scores.
        val_loader: Iterable yielding image batches (tensors).
        criterion: Loss callable taking ``(outputs, targets)``.
        device: Device on which batches, labels and noise are placed.

    Returns:
        float: Average of ``d_loss_real + d_loss_fake`` across batches, or
        ``nan`` if ``val_loader`` yields no batches.
    """
    generator.eval()
    discriminator.eval()

    total_loss = 0.0
    # Count batches while iterating so loaders without __len__ work too.
    num_batches = 0
    with torch.no_grad():
        for images in val_loader:
            images = images.to(device)
            batch_size = images.size(0)

            # Create labels directly on the target device (no CPU round trip).
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # Score the discriminator on real images
            d_loss_real = criterion(discriminator(images), real_labels)

            # Generate fake images using image-shaped random noise as input
            z = torch.randn(images.shape, device=device)
            d_loss_fake = criterion(discriminator(generator(z)), fake_labels)

            total_loss += (d_loss_real + d_loss_fake).item()
            num_batches += 1

    # An empty loader used to raise ZeroDivisionError; report "no data" instead.
    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

In [166]:
def train(generator, discriminator, train_loader, val_loader, num_epochs, criterion, optimizer_g, optimizer_d, device):
    """Alternately train the discriminator and generator, validating after each epoch.

    Per batch: the discriminator is updated on real images (target 1) and on
    detached generated images (target 0); then the generator is updated so that
    the discriminator scores its images as real.

    Args:
        generator: Module mapping an image-shaped noise tensor to an image tensor.
        discriminator: Module mapping an image batch to ``(batch, 1)`` scores.
        train_loader: Sized iterable yielding training image batches.
        val_loader: Iterable of validation batches, forwarded to ``validate``.
        num_epochs: Number of passes over ``train_loader``.
        criterion: Loss callable taking ``(outputs, targets)``.
        optimizer_g: Optimizer over the generator's parameters.
        optimizer_d: Optimizer over the discriminator's parameters.
        device: Device the models, batches, labels and noise are placed on.

    Returns:
        tuple[list[float], list[float]]: Per-epoch mean training loss
        (``d_loss + g_loss`` averaged over batches) and per-epoch validation loss.

    Raises:
        ValueError: If ``train_loader`` contains no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        # Fail fast; previously this surfaced as a ZeroDivisionError after the epoch.
        raise ValueError("train_loader yields no batches; cannot train on an empty dataset")

    train_losses = []
    val_losses = []

    generator.to(device)
    discriminator.to(device)

    for epoch in trange(num_epochs, desc='Epoch Loop', unit='epoch'):
        # validate() leaves both models in eval mode, so re-enable training mode each epoch.
        generator.train()
        discriminator.train()

        running_loss = 0.0
        for images in tqdm(train_loader, desc='Batch Loop', leave=False):
            images = images.to(device)
            batch_size = images.size(0)

            # Create labels directly on the target device
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # --- Discriminator step ---
            d_loss_real = criterion(discriminator(images), real_labels)

            # Generate fake images using image-shaped random noise as input
            z = torch.randn(images.shape, device=device)
            fake_images = generator(z)
            # detach(): the discriminator update must not backpropagate into the generator.
            d_loss_fake = criterion(discriminator(fake_images.detach()), fake_labels)

            d_loss = d_loss_real + d_loss_fake
            optimizer_d.zero_grad()
            d_loss.backward()
            optimizer_d.step()

            # --- Generator step: reward images the discriminator labels as real ---
            g_loss = criterion(discriminator(fake_images), real_labels)

            optimizer_g.zero_grad()
            g_loss.backward()
            optimizer_g.step()

            running_loss += d_loss.item() + g_loss.item()

        epoch_loss = running_loss / num_batches
        train_losses.append(epoch_loss)

        # Validation
        val_loss = validate(generator, discriminator, val_loader, criterion, device)
        val_losses.append(val_loss)

        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_loss}, Val Loss: {val_loss}')

    return train_losses, val_losses

In [167]:
def plot_losses(train_losses, val_losses):
    """Show the training and validation loss sequences as two lines on one figure."""
    _, ax = plt.subplots(figsize=(10, 5))
    curves = (
        (train_losses, 'Training Loss'),
        (val_losses, 'Validation Loss'),
    )
    for values, label in curves:
        ax.plot(values, label=label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.set_title('Training and Validation Losses')
    plt.show()

In [168]:
# Assuming train_loader and val_loader are already defined
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Seed PyTorch's RNG before building the models so weight initialisation and
# the noise drawn during training are repeatable across runs
# (some GPU kernels may still be nondeterministic).
torch.manual_seed(42)

# Initialize the generator and discriminator (fresh weights, replacing any earlier instances)
generator = Generator()
discriminator = Discriminator()

# Loss and optimizers: binary cross-entropy on the discriminator's sigmoid output,
# one Adam optimizer per network with identical hyperparameters.
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training parameters
num_epochs = 100


Device: cuda


In [169]:
# Run the full GAN training loop; returns the per-epoch loss histories.
train_losses, val_losses = train(
    generator=generator,
    discriminator=discriminator,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=num_epochs,
    criterion=criterion,
    optimizer_g=optimizer_g,
    optimizer_d=optimizer_d,
    device=device,
)


Epoch Loop:   0%|          | 0/100 [00:31<?, ?epoch/s]


KeyboardInterrupt: 

In [None]:
# Visualise the loss histories returned by train().
plot_losses(train_losses=train_losses, val_losses=val_losses)