In [183]:
# Imports
import argparse

import torch
from torch import nn, optim
from torch.autograd.variable import Variable

from torchvision import transforms, datasets
from torch.utils.data import DataLoader

from PIL import Image, ImageDraw, ImageFont
from torchvision.utils import save_image
import os

In [184]:
# Hyperparameters
latent_dim = 100  # Size of the noise vector
num_classes = 3  # Number of labels (rock, paper, scissors)
image_size = 170  # Image resolution
batch_size = 170  # Batch size
num_epochs = 200  # Number of training epochs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use GPU if available


In [185]:
parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=50, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=16, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='adam: beta 1')
parser.add_argument('--b2', type=float, default=0.999, help='adam: beta 2')
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--img_size_x', type=int, default=256, help='size of each image dimension')
parser.add_argument('--img_size_y', type=int, default=341, help='size of each image dimension')
parser.add_argument('--channels', type=int, default=3, help='number of image channels')
parser.add_argument('--n_classes', type=int, default=10, help='number of classes (e.g., digits 0 ..9, 10 classes on mnist)')
parser.add_argument('--display_port', type=int, default=8097, help='where to run the visdom for visualization? useful if running multiple visdom tabs')
parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
parser.add_argument('--sample_interval', type=int, default=5, help='interval between image samples')

# Prevent argparse from interpreting Jupyter arguments
opt = parser.parse_args(args=[])

# Print the arguments for verification
print(opt)


Namespace(n_epochs=50, batch_size=16, lr=0.0002, b1=0.5, b2=0.999, latent_dim=100, img_size_x=256, img_size_y=341, channels=3, n_classes=10, display_port=8097, display_server='http://localhost', sample_interval=5)


In [186]:
try:
    import visdom
    vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, raise_exceptions=True) # Create vis env.
except ImportError:
    vis = None
else:
    assert vis.check_connection(timeout_seconds=3), "No connection could be formed quickly"

Setting up a new session...


In [187]:
img_dims = (opt.channels, opt.img_size_x, opt.img_size_y)
n_features = opt.channels * opt.img_size_x * opt.img_size_y

In [188]:
def weights_init_xavier(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.xavier_normal_(m.weight.data, gain=0.02)
    elif classname.find('Linear') != -1:
        nn.init.xavier_normal_(m.weight.data, gain=0.02)
    elif classname.find('BatchNorm1d') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)


In [189]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # Map z & y (noise and label) into the hidden layer.
        # TO DO: How to run this with a function defined here?
        self.z_map = nn.Sequential(
            nn.Linear(opt.latent_dim, 200),
            nn.BatchNorm1d(200),
            nn.ReLU(inplace=True),
        )
        self.y_map = nn.Sequential(
            nn.Linear(opt.n_classes, 1000),
            nn.BatchNorm1d(1000),
            nn.ReLU(inplace=True),
        )
        self.zy_map = nn.Sequential(
            nn.Linear(1200, 1200),
            nn.BatchNorm1d(1200),
            nn.ReLU(inplace=True),
        )

        self.model = nn.Sequential(
            nn.Linear(1200, n_features),
            nn.Tanh()
        )
        # Tanh > Image values are between [-1, 1]


    def forward(self, z, y):
        zh = self.z_map(z)
        yh = self.y_map(y)
        zy = torch.cat((zh, yh), dim=1) # Combine noise and labels.
        zyh = self.zy_map(zy)
        x = self.model(zyh)
        x = x.view(x.size(0), *img_dims)
        return x

In [190]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(240, 1),
            nn.Sigmoid()
        )

        # Imitating a 3d array by combining second and third dimensions via multiplication for maxout.
        self.x_map = nn.Sequential(nn.Linear(n_features, 240 * 5))
        self.y_map = nn.Sequential(nn.Linear(opt.n_classes, 50 * 5))
        self.j_map = nn.Sequential(nn.Linear(240 + 50, 240 * 4))

    def forward(self, x, y):
        # maxout for x
        print(x.shape)
        x = x.view(-1, n_features)
        x = self.x_map(x)
        x, _ = x.view(-1, 240, 5).max(dim=2) # pytorch outputs max values and indices
        # .. and y
        y = y.view(-1, opt.n_classes)
        y = self.y_map(y)
        y, _ = y.view(-1, 50, 5).max(dim=2)
        # joint maxout layer
        jmx = torch.cat((x, y), dim=1)
        jmx = self.j_map(jmx)
        jmx, _ = jmx.view(-1, 240, 4).max(dim=2)

        prob = self.model(jmx)
        return prob


In [191]:
# Loading local dataset
transform = transforms.Compose([
    transforms.Resize((opt.img_size_x,opt.img_size_y)),  # Resize to the required image size
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))  # Normalize image to [-1, 1]
])

image_dir = './images'  # Path to your image dataset

# Load images using ImageFolder
dataset = datasets.ImageFolder(root=image_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True)

# Check image labels
print(dataset.classes)

# Check amount of images loaded
print(len(dataset))

['Paper', 'Rock', 'Scissor']
133


In [192]:
cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
gan_loss = nn.BCELoss()

generator = Generator()
discriminator = Discriminator()

optimizer_D = optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_G = optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))


# Loss record.
g_losses = []
d_losses = []
epochs = []
loss_legend = ['Discriminator', 'Generator']

if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()

# Weight initialization.
generator.apply(weights_init_xavier)
discriminator.apply(weights_init_xavier)


Discriminator(
  (model): Sequential(
    (0): Linear(in_features=240, out_features=1, bias=True)
    (1): Sigmoid()
  )
  (x_map): Sequential(
    (0): Linear(in_features=261888, out_features=1200, bias=True)
  )
  (y_map): Sequential(
    (0): Linear(in_features=10, out_features=250, bias=True)
  )
  (j_map): Sequential(
    (0): Linear(in_features=290, out_features=960, bias=True)
  )
)

In [193]:
for epoch in range(opt.n_epochs):
    print('Epoch {}'.format(epoch))
    for i, (batch, labels) in enumerate(dataloader):

        # Labels for real and fake images
        real = Variable(Tensor(batch.size(0), 1).fill_(1), requires_grad=False)
        fake = Variable(Tensor(batch.size(0), 1).fill_(0), requires_grad=False)

        # One-hot encode labels
        labels_onehot = Variable(Tensor(batch.size(0), opt.n_classes).zero_())
        labels_ = labels.type(LongTensor)  # Ensure labels are LongTensor
        labels_ = labels_.view(batch.size(0), 1)
        labels_onehot = labels_onehot.scatter_(1, labels_, 1)

        # Real and fake images
        imgs_real = Variable(batch.type(Tensor))
        noise = Variable(Tensor(batch.size(0), opt.latent_dim).normal_(0, 1))
        imgs_fake = generator(noise, labels_onehot)

        # == Discriminator update == #
        optimizer_D.zero_grad()
        d_loss = gan_loss(discriminator(imgs_real, labels_onehot), real) + \
                 gan_loss(discriminator(imgs_fake, labels_onehot), fake)
        d_loss.backward()
        optimizer_D.step()

        # == Generator update == #
        noise = Variable(Tensor(batch.size(0), opt.latent_dim).normal_(0, 1))
        imgs_fake = generator(noise, labels_onehot)
        optimizer_G.zero_grad()
        g_loss = gan_loss(discriminator(imgs_fake, labels_onehot), real)
        g_loss.backward()
        optimizer_G.step()

        # == Visdom updates == #
        if vis:
            batches_done = epoch * len(dataloader) + i
            print(f"Epoch: {epoch}, Batch: {i}, Batches Done: {batches_done}")
            if batches_done % opt.sample_interval >= 0:

                # Append losses for plotting
                epochs.append(batches_done)
                g_losses.append(g_loss.item())
                d_losses.append(d_loss.item())

                # Update loss plot
                vis.line(
                    X=torch.tensor([batches_done]),
                    Y=torch.tensor([[d_loss.item(), g_loss.item()]]),
                    win=1,
                    update='append' if batches_done > 0 else None,
                    opts={
                        'title': 'Loss over time',
                        'legend': ['D Loss', 'G Loss'],
                        'xlabel': 'Batches Done',
                        'ylabel': 'Loss',
                        'width': 512,
                        'height': 512,
                    }
                )

                # Update generated images
                noise = Variable(Tensor(5 * 10, opt.latent_dim).normal_(0, 1))
                labels_onehot = Variable(Tensor(5 * 10, opt.n_classes).zero_())
                labels_ = torch.arange(0, 10).repeat(5, 1).transpose(0, 1).contiguous().view(-1, 1)
                labels_ = labels_.type(LongTensor)
                labels_onehot = labels_onehot.scatter_(1, labels_, 1)
                imgs_fake = generator(noise, labels_onehot)

                vis.images(
                    imgs_fake.data[:50],
                    nrow=5,
                    win=2,
                    opts={
                        'title': 'Generated Images [Epoch {}, Batch {}]'.format(epoch, i),
                        'width': 512,
                        'height': 512,
                    }
                )
                print(f"Epoch {epoch}, Batch {i}, Batches Done: {batches_done}")
                print(f"D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")



Epoch 0
torch.Size([16, 3, 256, 341])
torch.Size([16, 3, 256, 341])
torch.Size([16, 3, 256, 341])
Epoch: 0, Batch: 0, Batches Done: 0
Epoch 0, Batch 0, Batches Done: 0
D Loss: 1.3863874673843384, G Loss: 0.6840237975120544
torch.Size([16, 3, 256, 341])
torch.Size([16, 3, 256, 341])
torch.Size([16, 3, 256, 341])
Epoch: 0, Batch: 1, Batches Done: 1
Epoch 0, Batch 1, Batches Done: 1
D Loss: 1.355218529701233, G Loss: 0.6666851043701172
torch.Size([16, 3, 256, 341])
torch.Size([16, 3, 256, 341])
torch.Size([16, 3, 256, 341])
Epoch: 0, Batch: 2, Batches Done: 2
Epoch 0, Batch 2, Batches Done: 2
D Loss: 1.283057451248169, G Loss: 0.5031737089157104


KeyboardInterrupt: 

In [None]:

# Generate images
print(imgs_fake[j].shape)
# Generate images
noise = Variable(Tensor(5*10, opt.latent_dim).normal_(0, 1))
labels_onehot = Variable(Tensor(5*10, opt.n_classes).zero_())
labels_ = torch.range(0, 9)
labels_ = labels_.view(1, -1).repeat(5, 1).transpose(0, 1).contiguous().view(1, -1)
labels_ = labels_.type(LongTensor)
labels_ = labels_.view(-1, 1)
labels_onehot = labels_onehot.scatter_(1, labels_, 1)
imgs_fake = generator(noise, labels_onehot)

# Save the generated images with labels overlaid
save_dir = 'C:/Users/Timoy/Documents/GitHub/evml-evd3-project/Project_2/prep/prepgenerated_images'
os.makedirs(save_dir, exist_ok=True)

# Convert each generated image to PIL and overlay the label
for j in range(imgs_fake.size(0)):
    img = imgs_fake[j].cpu().detach().numpy().transpose(1, 2, 0)  # Convert tensor to numpy (H, W, C)

    # Check the image range before applying any transformations
    print(f"Image {j} min: {img.min()}, max: {img.max()}")  # Ensure the range is [-1, 1]

    # Denormalize the image (convert from [-1, 1] to [0, 1])
    img = (img + 1) / 2
    
    # Ensure the pixel values are in the range [0, 255]
    img = (img * 255).clip(0, 255).astype('uint8')  # Convert to uint8 format for PIL
    
    # Squeeze the channel dimension (1, H, W) -> (H, W) for grayscale images
    img = img.squeeze(axis=2)  # Remove the single channel dimension for grayscale images
    
    # Convert numpy array to PIL image
    pil_img = Image.fromarray(img)
    
    # Overlay text (label) on the image
    draw = ImageDraw.Draw(pil_img)
    label = int(labels_[j].item())  # Get the label for the image
    draw.text((10, 10), f'Label: {label}', fill='white')  # Draw text on image (top-left corner)
    
    # Save the image
    pil_img.save(os.path.join(save_dir, f'epoch_{epoch}_batch_{i}_img_{j}_label_{label}.png'))



torch.Size([1, 256, 341])


  labels_ = torch.range(0, 9)


Image 0 min: -0.1117442175745964, max: 0.4447222650051117
Image 1 min: -0.06747356057167053, max: 0.3155316710472107
Image 2 min: -0.768531084060669, max: 0.9701910614967346
Image 3 min: -0.0633787140250206, max: 0.26418811082839966
Image 4 min: -0.518787682056427, max: 0.8570298552513123
Image 5 min: -0.33012157678604126, max: 0.7592254877090454
Image 6 min: -0.124612957239151, max: 0.4686332643032074
Image 7 min: -0.8240763545036316, max: 0.9849047660827637
Image 8 min: -0.09766220301389694, max: 0.3914071023464203
Image 9 min: -0.6644922494888306, max: 0.9387771487236023
Image 10 min: -1.0, max: 1.0
Image 11 min: -1.0, max: 1.0
Image 12 min: -1.0, max: 1.0
Image 13 min: -0.9999998211860657, max: 1.0
Image 14 min: -1.0, max: 1.0
Image 15 min: -0.7621409893035889, max: 0.953784167766571
Image 16 min: -0.6685002446174622, max: 0.9178050756454468
Image 17 min: -0.36280301213264465, max: 0.6913214325904846
Image 18 min: -0.6035997867584229, max: 0.8606859445571899
Image 19 min: -0.661845