In [45]:
import os
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


In [46]:

print(torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


True
Using device: cuda


In [47]:

# Reading all images from directory
Images = {}
Folder = 'Grouped_Output'
group_to_id = {}
for i, filename in enumerate(os.listdir(Folder)):
    Images[filename] = []
    group_to_id[filename] = i
    for img in os.listdir(Folder + '/' + filename):
        Images[filename].append(img)

num_groups = len(group_to_id)


In [48]:

# Pre-processing the images
transform = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])

# Applying transformation to all images
Images_tensor = {}
for folder in Images.keys():
    Images_tensor[folder] = []
    for img in Images[folder]:
        img = Image.open(Folder + '/' + folder + '/' + img)
        img = transform(img).to(device)
        Images_tensor[folder].append(img)


In [49]:

class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

class Decoder(nn.Module):
    def __init__(self, latent_dim, num_groups):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim + num_groups, 256 * 4 * 4)
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        self.deconv4 = nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1)

    def forward(self, z, group):
        z = torch.cat([z, group], dim=1)
        z = self.fc(z)
        z = z.view(z.size(0), 256, 4, 4)
        z = F.relu(self.deconv1(z))
        z = F.relu(self.deconv2(z))
        z = F.relu(self.deconv3(z))
        z = torch.sigmoid(self.deconv4(z))
        return z


In [50]:
class VAE(nn.Module):
    def __init__(self, latent_dim, num_groups):
        super(VAE, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim, num_groups)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, group):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        group = group.unsqueeze(0).expand(z.size(0), -1)
        return self.decoder(z, group), mu, logvar

def vae_loss(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD


In [51]:
class Generator(nn.Module):
    def __init__(self, latent_dim, num_groups):
        super(Generator, self).__init__()
        self.fc = nn.Linear(latent_dim + num_groups, 256 * 4 * 4)  # Fully connected layer to map latent vector
        
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)  # Output: 128x8x8
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)   # Output: 64x16x16
        self.deconv3 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)    # Output: 32x32x32
        self.deconv4 = nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1)     # Output: 3x64x64

    def forward(self, z, group):
        group = group.unsqueeze(0).expand(z.size(0), -1)  # Add group info to latent vector
        x = torch.cat([z, group], dim=1)
        x = F.relu(self.fc(x))  # Fully connected layer
        x = x.view(x.size(0), 256, 4, 4)  # Reshape to 256x4x4 for ConvTranspose2d
        
        x = F.relu(self.deconv1(x))  # 128x8x8
        x = F.relu(self.deconv2(x))  # 64x16x16
        x = F.relu(self.deconv3(x))  # 32x32x32
        x = torch.tanh(self.deconv4(x))  # 3x64x64 (tanh is often used to keep values between -1 and 1)
        
        return x


class Discriminator(nn.Module):
    def __init__(self, num_groups):
        super(Discriminator, self).__init__()
        
        self.conv1 = nn.Conv2d(3 + num_groups, 32, kernel_size=4, stride=2, padding=1)  # Input: 3+num_groups channels, Output: 32x32x32
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)  # Output: 64x16x16
        self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)  # Output: 128x8x8
        self.conv4 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)  # Output: 256x4x4
        
        self.fc = nn.Linear(256 * 4 * 4, 1)  # Fully connected layer for binary classification

    def forward(self, x, group):
        group = group.view(group.size(0), -1, 1, 1)  # Reshape group embedding to be broadcastable
        group = group.expand(group.size(0), group.size(1), 64, 64)  # Expand to match image size (64x64)
        x = torch.cat([x, group], dim=1)  # Concatenate image and group embedding along the channel dimension

        x = F.leaky_relu(self.conv1(x), 0.2)  # 32x32x32
        x = F.leaky_relu(self.conv2(x), 0.2)  # 64x16x16
        x = F.leaky_relu(self.conv3(x), 0.2)  # 128x8x8
        x = F.leaky_relu(self.conv4(x), 0.2)  # 256x4x4

        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = torch.sigmoid(self.fc(x))  # Binary output (real or fake)
        
        return x


In [52]:
# Initialize models
latent_dim = 20
vae = VAE(latent_dim=latent_dim, num_groups=num_groups).to(device)
generator = Generator(latent_dim, num_groups).to(device)
discriminator = Discriminator(num_groups).to(device)

# Optimizers
optimizer_vae = torch.optim.Adam(vae.parameters(), lr=1e-3)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

# Loss functions
adversarial_loss = nn.BCELoss()

In [53]:
def train_vae(num_epochs):
    for epoch in range(num_epochs):
        total_loss = 0
        num_batches = 0
        for folder in Images_tensor.keys():
            group_id = group_to_id[folder]
            group_onehot = F.one_hot(torch.tensor(group_id), num_classes=num_groups).float().to(device)
            for img in Images_tensor[folder]:
                img = img.unsqueeze(0)  # Add batch dimension
                recon_img, mu, logvar = vae(img, group_onehot)
                loss = vae_loss(recon_img, img, mu, logvar)
                optimizer_vae.zero_grad()
                loss.backward()
                optimizer_vae.step()
                total_loss += loss.item()
                num_batches += 1
        avg_loss = total_loss / num_batches
        print(f'VAE Epoch: {epoch+1}, Average Loss: {avg_loss:.4f}')


In [54]:
def train_gan(num_epochs, switch_epoch_G, switch_epoch_D, joint_train_start_epoch):
    for epoch in range(num_epochs):
        total_d_loss = 0
        total_g_loss = 0
        num_batches = 0
        for folder in Images_tensor.keys():
            group_id = group_to_id[folder]
            group_onehot = F.one_hot(torch.tensor(group_id), num_classes=num_groups).float().to(device)
            for img in Images_tensor[folder]:
                img = img.view(-1, 64*64*3)
                batch_size = img.size(0)
                
                # Train only Discriminator for specified epochs
                if epoch < switch_epoch_D or (epoch >= joint_train_start_epoch and epoch % 2 == 0):
                    real_label = torch.ones(batch_size, 1).to(device)
                    fake_label = torch.zeros(batch_size, 1).to(device)
                    
                    optimizer_D.zero_grad()
                    real_output = discriminator(img, group_onehot)
                    d_loss_real = adversarial_loss(real_output, real_label)
                    
                    z = torch.randn(batch_size, latent_dim).to(device)
                    fake_img = generator(z, group_onehot)
                    fake_output = discriminator(fake_img.detach(), group_onehot)
                    d_loss_fake = adversarial_loss(fake_output, fake_label)
                    
                    d_loss = d_loss_real + d_loss_fake
                    d_loss.backward()
                    optimizer_D.step()
                    
                    total_d_loss += d_loss.item()

                # Train only Generator for specified epochs
                if epoch >= switch_epoch_G or (epoch >= joint_train_start_epoch and epoch % 2 != 0):
                    optimizer_G.zero_grad()
                    z = torch.randn(batch_size, latent_dim).to(device)
                    fake_img = generator(z, group_onehot)
                    output = discriminator(fake_img, group_onehot)
                    g_loss = adversarial_loss(output, real_label)
                    g_loss.backward()
                    optimizer_G.step()

                    total_g_loss += g_loss.item()
                
                num_batches += 1
        
        avg_d_loss = total_d_loss / num_batches if total_d_loss > 0 else 0
        avg_g_loss = total_g_loss / num_batches if total_g_loss > 0 else 0
        print(f'Epoch {epoch+1}, Avg D Loss: {avg_d_loss:.4f}, Avg G Loss: {avg_g_loss:.4f}')

In [55]:
# Training
vae_epochs = 25
gan_epochs = 250

In [56]:
print("Training VAE...")
train_vae(vae_epochs)

Training VAE...
VAE Epoch: 1, Average Loss: 7967.4876
VAE Epoch: 2, Average Loss: 7896.5509
VAE Epoch: 3, Average Loss: 7849.9287
VAE Epoch: 4, Average Loss: 7838.5224
VAE Epoch: 5, Average Loss: 7823.0136
VAE Epoch: 6, Average Loss: 7809.0817
VAE Epoch: 7, Average Loss: 7799.2621
VAE Epoch: 8, Average Loss: 7793.0351
VAE Epoch: 9, Average Loss: 7788.5469
VAE Epoch: 10, Average Loss: 7783.3222
VAE Epoch: 11, Average Loss: 7780.1386
VAE Epoch: 12, Average Loss: 7776.6827
VAE Epoch: 13, Average Loss: 7774.8272
VAE Epoch: 14, Average Loss: 7771.8794
VAE Epoch: 15, Average Loss: 7769.4377
VAE Epoch: 16, Average Loss: 7768.1334
VAE Epoch: 17, Average Loss: 7766.2308
VAE Epoch: 18, Average Loss: 7764.0874
VAE Epoch: 19, Average Loss: 7762.7476
VAE Epoch: 20, Average Loss: 7760.3775
VAE Epoch: 21, Average Loss: 7758.7426
VAE Epoch: 22, Average Loss: 7757.4540
VAE Epoch: 23, Average Loss: 7755.2532
VAE Epoch: 24, Average Loss: 7755.0627
VAE Epoch: 25, Average Loss: 7752.5326


In [57]:
# Parameters
switch_epoch_G = 5  # Train only Generator after epoch 5
switch_epoch_D = 3  # Train only Discriminator after epoch 3
joint_train_start_epoch = 10  # Start joint training from epoch 10
print("Training GAN...")
train_gan(gan_epochs, switch_epoch_G, switch_epoch_D, joint_train_start_epoch)


Training VAE...


RuntimeError: Tensors must have same number of dimensions: got 2 and 4

In [None]:
# Save the models
torch.save(vae.state_dict(), 'vae_model.pth')
torch.save(generator.state_dict(), 'generator_model.pth')
torch.save(discriminator.state_dict(), 'discriminator_model.pth')

In [None]:
def generate_signature(group_name):
    if group_name not in group_to_id:
        print(f"Group '{group_name}' not found.")
        return None
    
    group_id = group_to_id[group_name]
    group_onehot = F.one_hot(torch.tensor(group_id), num_classes=num_groups).float().to(device)
    
    z = torch.randn(1, latent_dim).to(device)
    fake_img = generator(z, group_onehot)
    fake_img = fake_img.view(3, 64, 64).permute(1, 2, 0).cpu().detach()
    
    plt.figure(figsize=(8, 8))
    plt.imshow(fake_img)
    plt.title(f"Generated Signature for Group: {group_name}")
    plt.axis('off')
    plt.savefig(f'Generated_Signature_{group_name}.png')
    plt.close()
    
    print(f"Generated signature for group '{group_name}' saved as 'Generated_Signature_{group_name}.png'")
    return fake_img

# Example usage:
generate_signature("Group_11")  # Replace "Group1" with the actual group name you want to generate

In [None]:
print(group_to_id)