In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import transforms, models
import torchvision
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.tensorboard import SummaryWriter

import os
from PIL import Image
from IPython.display import display

import warnings
warnings.filterwarnings('ignore')

import torch.optim as optim

In [2]:
class CustomDataset(Dataset):
    def __init__(self, image_dir, labels_csv, transform=None):
        

        self.labels_df = pd.read_csv(labels_csv)

        self.label_map = {row["FeatID"]: row["class"] for index, row in self.labels_df.iterrows()}
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.label_map)

    def __getitem__(self, idx):
        feat_id = list(self.label_map.keys())[idx]
        img_name = f"featID {feat_id}.png"
        img_path = os.path.join(self.image_dir, img_name)
        image = Image.open(img_path)
        label = self.label_map[feat_id]

        if self.transform:
            image = self.transform(image)

        return image, label

In [3]:
image_directory = '/Users/karthik/GANS/4by4-TOPO-24X224/renamed_images'
labels_csv = '/Users/karthik/AeruginosaWithClass.csv'

**Parameters**

In [4]:
n_epochs = 200

image_dim = 1*224*224

n_class = 5 

latent_dim = 100

lr = 0.0002

b1 = 0.5

b2 = 0.999

batch_size = 64

In [5]:
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [6]:
dataset = CustomDataset(image_directory, labels_csv, transform=transform)

data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

for images, labels in data_loader:

    pass

In [7]:
total_images = len(dataset)
total_labels = len(set(dataset.labels_df['class']))  

print(f"Total images: {total_images}")
print(f"Total labels: {total_labels}")

Total images: 2100
Total labels: 5


In [8]:
total_batches = len(data_loader)
print(f"Total batches: {total_batches}")

Total batches: 33


In [9]:
for i, (images, labels) in enumerate(data_loader, 1):
    print(f"Batch {i} has {len(images)} images and {len(labels)} labels")
    
    break

Batch 1 has 64 images and 64 labels


**Creating a Loop over the data loader to get the labels in the first batch**

In [10]:
first_batch_labels = next(iter(data_loader))[1] 

print("Labels in the first batch:")
print(first_batch_labels)

Labels in the first batch:
tensor([1, 1, 3, 2, 0, 3, 0, 0, 1, 1, 1, 0, 2, 0, 3, 3, 0, 3, 2, 0, 3, 0, 0, 0,
        3, 1, 1, 4, 2, 1, 0, 0, 1, 3, 0, 0, 3, 4, 1, 3, 0, 3, 0, 1, 1, 1, 0, 3,
        2, 0, 3, 0, 0, 4, 3, 1, 1, 3, 2, 1, 0, 1, 2, 2])


**Generator Architecture**

In [14]:

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.label_emb = nn.Embedding(n_class, n_class)

        # Define the fully connected layers
        self.fc1 = nn.Linear(latent_dim + n_class, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 512)
        self.fc4 = nn.Linear(512, 1024)
        self.fc5 = nn.Linear(1024, image_dim)

        # Define the batch normalization layers
        self.batch_norm2 = nn.BatchNorm1d(256, 0.8)
        self.batch_norm3 = nn.BatchNorm1d(512, 0.8)
        self.batch_norm4 = nn.BatchNorm1d(1024, 0.8)

    def forward(self, noise, labels):
        
        # Concatenate label embedding and image to produce input
        x = torch.cat((self.label_emb(labels), noise), -1)
        x = F.leaky_relu(self.fc1(x), 0.2)

        x = self.fc2(x)
        x = self.batch_norm2(x)
        x = F.leaky_relu(x, 0.2)

        x = self.fc3(x)
        x = self.batch_norm3(x)
        x = F.leaky_relu(x, 0.2)

        x = self.fc4(x)
        x = self.batch_norm4(x)
        x = F.leaky_relu(x, 0.2)

        x = self.fc5(x)
        img = torch.tanh(x)
        img = img.view(img.size(0), image_dim)
        return img

**Discriminator Architecture**

In [18]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_embedding = nn.Embedding(n_class, n_class)

        self.model = nn.Sequential(
            nn.Linear(n_class + image_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img, labels):
        # Concatenate label embedding and image to produce input
        d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), dim=-1)
        validity = self.model(d_in)
        return validity

**Optiminser**

In [19]:
adversarial_loss = torch.nn.MSELoss()

generator = Generator()
discriminator = Discriminator()

In [20]:
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

**Tensorboard Summary**

from torch.utils.tensorboard import SummaryWriter 

writer_fake = SummaryWriter(f"logs/fake")
writer_real = SummaryWriter(f"logs/real")
step = 0

**Training Loop**

In [22]:
FloatTensor = torch.FloatTensor
LongTensor = torch.LongTensor

In [24]:
for epoch in range(n_epochs):
    for i, (imgs, labels) in enumerate(data_loader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = FloatTensor(batch_size, 1).fill_(1.0)
        fake = FloatTensor(batch_size, 1).fill_(0.0)

        # Configure input
        real_imgs = imgs.type(FloatTensor)
        labels = labels.type(LongTensor)

        # Training Generator
        optimizer_G.zero_grad()

        # Sample noise and labels as generator input
        z = FloatTensor(np.random.normal(0, 1, (batch_size, latent_dim)))
        gen_labels = LongTensor(np.random.randint(0, n_class, batch_size))

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # Loss measures generator's ability to fool the discriminator
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

        # Training Discriminator
        optimizer_D.zero_grad()

        # Loss for real images
        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real, valid)

        # Loss for fake images
        validity_fake = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake, fake)

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, n_epochs, i, len(data_loader), d_loss.item(), g_loss.item())
        )


[Epoch 0/200] [Batch 0/33] [D loss: 0.494994] [G loss: 0.929005]
[Epoch 0/200] [Batch 1/33] [D loss: 2.778057] [G loss: 0.878239]
[Epoch 0/200] [Batch 2/33] [D loss: 0.280956] [G loss: 0.863225]
[Epoch 0/200] [Batch 3/33] [D loss: 0.107295] [G loss: 0.843420]
[Epoch 0/200] [Batch 4/33] [D loss: 0.235785] [G loss: 0.823319]
[Epoch 0/200] [Batch 5/33] [D loss: 0.083146] [G loss: 0.813332]
[Epoch 0/200] [Batch 6/33] [D loss: 0.123867] [G loss: 0.787185]
[Epoch 0/200] [Batch 7/33] [D loss: 0.083375] [G loss: 0.755836]
[Epoch 0/200] [Batch 8/33] [D loss: 0.093909] [G loss: 0.731295]
[Epoch 0/200] [Batch 9/33] [D loss: 0.063474] [G loss: 0.705369]
[Epoch 0/200] [Batch 10/33] [D loss: 0.057049] [G loss: 0.662764]
[Epoch 0/200] [Batch 11/33] [D loss: 0.054767] [G loss: 0.638120]
[Epoch 0/200] [Batch 12/33] [D loss: 0.065212] [G loss: 0.608088]
[Epoch 0/200] [Batch 13/33] [D loss: 0.068332] [G loss: 0.590467]
[Epoch 0/200] [Batch 14/33] [D loss: 0.069035] [G loss: 0.592488]
[Epoch 0/200] [Batch

[E thread_pool.cpp:110] Exception in thread pool task: mutex lock failed: Invalid argument


KeyboardInterrupt: 