In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os

In [2]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.optim as optim

In [3]:
# Input/output locations, relative to the kernel's working directory.
data_path = "../data/input/0_datasets/L1_7/train_B"
results_path = "../data/results/L1_7"

# Seed torch's RNG so weight initialisation, DataLoader shuffling and the
# latent noise drawn during training are reproducible across re-runs.
SEED = 42
torch.manual_seed(SEED)

# Data

In [4]:
# Data loading and transformation.
# ToTensor scales pixels to [0, 1]; Normalize with mean = std = 0.5 then maps
# them to [-1, 1], the same range the generator's final Tanh produces, so real
# and generated images are on a common scale for the discriminator.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
# NOTE(review): the Generator in the Model section outputs 3x128x128 images,
# while image_size suggests the source images are 3x1024x2048. Consider adding
# a transforms.Resize so real and fake batches share one shape -- confirm the
# intended resolution.

# ImageFolder expects one sub-directory per class under data_path; the class
# labels it returns are ignored by the GAN code below.
image_dataset = datasets.ImageFolder(root=data_path, transform=transform)

# Hyperparameters
batch_size = 64
learning_rate = 0.0002
num_epochs = 50

# Shuffle so each epoch visits the images in a different order.
train_loader = DataLoader(image_dataset, batch_size=batch_size, shuffle=True)

In [5]:
# Number of elements (C * H * W) in one transformed training image. It is read
# from the dataset instead of being hard-coded, so it stays correct if the
# images or the transform change. Assumes all images share the first one's size.
image_size = image_dataset[0][0].numel()

# Model

In [6]:
import torch

In [7]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise vectors to RGB images.

    A fully connected layer projects each latent vector to a 512x8x8 feature
    map. Four stride-2 transposed convolutions then upsample it to 3x128x128
    (8 -> 16 -> 32 -> 64 -> 128), and a final Tanh bounds values to [-1, 1].

    Args:
        latent_dim: length of each input noise vector (default 100).
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        self.latent_dim = latent_dim
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 512 * 8 * 8),
            nn.ReLU(True)
        )
        # Each ConvTranspose2d(kernel_size=4, stride=2, padding=1) doubles H and W.
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Generate a batch of images from latent vectors.

        Args:
            x: tensor of shape (batch, latent_dim), e.g. torch.randn(batch, latent_dim).

        Returns:
            Tensor of shape (batch, 3, 128, 128) with values in [-1, 1].

        Raises:
            ValueError: if x is not a (batch, latent_dim) tensor, e.g. when an
                image batch is passed instead of noise. Without this check,
                some wrong shapes would be silently reshaped into a bogus batch.
        """
        if x.dim() != 2 or x.size(1) != self.latent_dim:
            raise ValueError(
                f"Generator expects latent vectors of shape (batch, {self.latent_dim}), "
                f"got {tuple(x.shape)}"
            )
        features = self.fc(x).view(x.size(0), 512, 8, 8)
        return self.conv(features)

# Initialize the generator
generator = Generator()

In [8]:
# Smoke test: the generator maps latent noise of shape (batch, 100) to images.
# Real images must not be passed to it -- its first layer is nn.Linear(100, ...).
real_images, _ = next(iter(train_loader))
z = torch.randn(real_images.size(0), 100)
with torch.no_grad():
    fake_images = generator(z)
real_images.shape, fake_images.shape

In [None]:
# Training loop.
# NOTE(review): `discriminator`, `criterion`, `optimizer_d` and `optimizer_g` are
# not defined in the cells above; they must be created before this cell for
# Restart & Run All to work. D(x) / D(G(z)) are logged as scores, which presumably
# means the discriminator ends in a sigmoid -- confirm.
log_every = 200

for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Use the actual batch size: the last batch (or the only batch, when the
        # dataset has fewer than batch_size images) can be smaller than batch_size.
        current_batch_size = images.size(0)

        # Flatten the images. Fake images are flattened the same way below so the
        # discriminator always receives (batch, features) inputs.
        # NOTE(review): the two feature lengths only match if real images are
        # 3x128x128, the Generator's output size.
        images = images.view(current_batch_size, -1)

        # Create labels
        real_labels = torch.ones(current_batch_size, 1)
        fake_labels = torch.zeros(current_batch_size, 1)

        # Train the discriminator on real images ...
        outputs = discriminator(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # ... and on generated ones, detached so this step does not update G.
        z = torch.randn(current_batch_size, 100)
        fake_images = generator(z).view(current_batch_size, -1)
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()

        # Train the generator: its loss is low when D labels the fakes as real.
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)

        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()

        # Log every `log_every` steps and at the end of each epoch, so datasets
        # with fewer than `log_every` batches still report progress.
        if (i + 1) % log_every == 0 or (i + 1) == len(train_loader):
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], '
                  f'd_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, '
                  f'D(x): {real_score.mean().item():.4f}, D(G(z)): {fake_score.mean().item():.4f}')

# Save the models; create the results directory first so torch.save cannot fail on it.
os.makedirs(results_path, exist_ok=True)
torch.save(generator.state_dict(), f'{results_path}/generator.pth')
torch.save(discriminator.state_dict(), f'{results_path}/discriminator.pth')