In [None]:
import os
import time
import warnings
from datetime import datetime

import imageio
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.utils as vutils
from torch.utils.data import DataLoader

from data_util import text2ImageDataset
from text2image import Generator, Discriminator
from utils import process_caption, weights_init

# Silences every warning for the rest of the session (deprecations included);
# narrow this filter if a specific warning needs to surface.
warnings.filterwarnings("ignore")

In [None]:
# Prefer a CUDA GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using the Device which is ", device)

In [None]:
# Wall-clock reference used for the final "Total train time" report.
start_time = time.time()
# YYMMDD stamp embedded in every artifact file name this run writes.
date = datetime.now().strftime("%y%m%d")

In [None]:
# Destination for sample grids, loss curves and the progress GIF (created if missing).
output_save_path = './generated_images/'
os.makedirs(name=output_save_path, exist_ok=True)

In [None]:
# Destination for generator / discriminator state_dict checkpoints (created if missing).
model_save_path = './saved_models/'
os.makedirs(name=model_save_path, exist_ok=True)

In [None]:
# Training hyperparameters.
noise_dimension = 100        # length of the latent noise vector fed to the generator
embed_dimension = 1024       # text-embedding size passed to G and D (presumably matches the dataset's embeddings)
embed_out_dimension = 128    # passed to G and D as embed_output_dimension
batch_size = 256             # mini-batch size for the training DataLoader
real_label = 1.              # BCE target for (real image, matching embedding)
fake_label = 0.              # BCE target for generated or mismatched pairs
learning_rate = 0.0002       # Adam learning rate shared by both networks
l1_coefficient = 50          # weight of the pixel-wise L1 term in the generator loss
l2_coefficient = 100         # weight intended for the feature-matching (L2) term in the generator loss
num_of_epochs = 250
log_interval = 18            # print losses every `log_interval` batches

# Seed the torch RNG so noise sampling, loader shuffling and any torch-based
# weight initialisation are reproducible across Restart & Run All.
seed = 42
torch.manual_seed(seed);

In [None]:
from torch.utils.data import DataLoader

# Training dataset, shuffled into mini-batches of `batch_size`.
# NOTE(review): the empty first argument and split=0 are passed straight to
# text2ImageDataset; presumably "default data location" and "training split" — confirm in data_util.
train_dataset = text2ImageDataset('', split=0)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
print("Number of Batches: ", len(train_loader))

In [None]:
# Adversarial loss. BCELoss expects probabilities in [0, 1], so this presumes the
# discriminator's output ends in a sigmoid — TODO confirm in text2image.
criterion = nn.BCELoss()
# Auxiliary generator losses: pixel-wise L1 and an MSE on discriminator activations.
l1_loss = nn.L1Loss()
l2_loss = nn.MSELoss()

# Per-iteration loss histories, appended by the training loop and plotted afterwards.
G_losses = []
D_losses = []

# Training

In [None]:
# Build the generator on the selected device, then apply weights_init to it and
# every submodule (nn.Module.apply recurses). The original `aaply` raised AttributeError.
# channels=3 presumably means RGB output — confirm in text2image.
generator = Generator(
    channels=3,
    embed_dimension=embed_dimension,
    noise_dimension=noise_dimension,
    embed_output_dimension=embed_out_dimension,
).to(device)
generator.apply(weights_init)

In [None]:
# Build the discriminator, move it to the selected device, then apply weights_init
# to it and every submodule.
discriminator = Discriminator(
    channels=3,
    embed_dimension=embed_dimension,
    embed_output_dimension=embed_out_dimension,
)
discriminator = discriminator.to(device)
discriminator.apply(weights_init)

In [None]:
# Both networks use Adam with the same learning rate and betas=(0.5, 0.999).
adam_betas = (0.5, 0.999)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=adam_betas)
optimizer_discriminarator = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=adam_betas)


In [None]:
# Training loop: for every batch, one discriminator update followed by one generator update.


def _discriminator_step(images, wrong_images, embeddings):
    """Run one discriminator update and return its scalar loss tensor.

    The discriminator is pushed towards `real_label` for (real image, its embedding)
    and towards `fake_label` for both (wrong image, embedding) and (generated image, embedding).
    """
    optimizer_discriminarator.zero_grad()

    # Generated images are plain inputs here, so no generator graph is needed.
    with torch.no_grad():
        noise = torch.randn(images.size(0), noise_dimension, 1, 1, device=device)
        fake_images = generator(noise, embeddings)

    real_out, _ = discriminator(images, embeddings)
    d_loss_real = criterion(real_out, torch.full_like(real_out, real_label))
    wrong_out, _ = discriminator(wrong_images, embeddings)
    d_loss_wrong = criterion(wrong_out, torch.full_like(wrong_out, fake_label))
    fake_out, _ = discriminator(fake_images, embeddings)
    d_loss_fake = criterion(fake_out, torch.full_like(fake_out, fake_label))

    d_loss = d_loss_real + d_loss_wrong + d_loss_fake
    d_loss.backward()
    optimizer_discriminarator.step()
    return d_loss


def _generator_step(images, embeddings):
    """Run one generator update and return (loss tensor, generated images).

    Loss = BCE(D(generated), real_label)
         + l1_coefficient * L1(generated, real images)
         + l2_coefficient * MSE(batch-mean D activations on generated, batch-mean on real)
    """
    optimizer_generator.zero_grad()

    # Noise must be created on the same device the generator was moved to.
    noise = torch.randn(images.size(0), noise_dimension, 1, 1, device=device)
    fake_images = generator(noise, embeddings)
    out_fake, act_fake = discriminator(fake_images, embeddings)
    # Real-image activations are only a fixed target, so their graph is not built.
    with torch.no_grad():
        _, act_real = discriminator(images, embeddings)

    g_adv = criterion(out_fake, torch.full_like(out_fake, real_label))
    g_l1 = l1_coefficient * l1_loss(fake_images, images)
    g_l2 = l2_coefficient * l2_loss(torch.mean(act_fake, 0), torch.mean(act_real, 0))

    g_loss = g_adv + g_l1 + g_l2
    g_loss.backward()
    optimizer_generator.step()
    return g_loss, fake_images


num_batches = len(train_loader)
for epoch in range(num_of_epochs):
    epoch_start = time.time()
    for batch_index, batch in enumerate(train_loader):
        images = batch['correct_images'].to(device)
        wrong_images = batch['wrong_images'].to(device)
        embeddings = batch['correct_embed'].to(device)

        d_loss = _discriminator_step(images, wrong_images, embeddings)
        g_loss, fake_images = _generator_step(images, embeddings)

        D_losses.append(d_loss.item())
        G_losses.append(g_loss.item())

        if (batch_index + 1) % log_interval == 0:
            print('Epoch {} [{}/{}] loss_discriminator: {: .4f} loss_genarator: {: .4f} time:{: .2f}'.format(
                epoch + 1, batch_index + 1, num_batches,
                d_loss.item(), g_loss.item(), time.time() - epoch_start
            ))

        # After the last batch of epoch 1 and of every 10th epoch, save up to 32 real
        # images followed by up to 32 generated ones as a single grid, 8 per row.
        if batch_index == num_batches - 1 and (epoch == 0 or (epoch + 1) % 10 == 0):
            viz_sample = torch.cat((images[:32], fake_images[:32].detach()), 0)
            vutils.save_image(
                viz_sample,
                os.path.join(output_save_path, "output_{}_epoch_{}.png".format(date, epoch + 1)),
                nrow=8,
                normalize=True,
            )

    # Checkpoint once per epoch. File names carry only the date, so each save
    # overwrites the previous checkpoint from the same day.
    torch.save(generator.state_dict(), os.path.join(model_save_path, 'generator_{}.pth'.format(date)))
    torch.save(discriminator.state_dict(), os.path.join(model_save_path, 'discriminator_{}.pth'.format(date)))

# start_time is taken in an earlier cell, so this also includes time spent before the loop.
print('Total train time: {: .2f}'.format(time.time() - start_time))


In [None]:
# Generator loss per training iteration. The figure is saved before plt.show(),
# because show() can close the figure and leave savefig writing an empty canvas.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator Loss during Training")
ax.plot(G_losses)
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
fig.savefig(os.path.join(output_save_path, 'output_geenration_Loss_{}.png'.format(date)))
plt.show()


In [None]:
# discriminator loss plot
plt.figure(figsize=(10,5))
plt.title("Discriminator Loss During Training")
plt.plot(D_losses)
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.show()

plt.savefig(os.path.join(output_save_path, 'output_discriminatorLoss_{}.png'.format(date)))

In [None]:
# Gather this run's sample grids (output_<date>_epoch_<n>.png) and stitch them,
# in epoch order, into an animated GIF.
sample_prefix = 'output_{}_'.format(date)
file_names = [
    name for name in os.listdir(output_save_path)
    if name.startswith(sample_prefix) and name.endswith('.png')
]
# The date stamp has no underscores, so split('_')[3] is '<n>.png'. Sorting on the
# integer keeps epoch 10 after epoch 9 (a plain string sort would not).
file_names.sort(key=lambda name: int(name.split('_')[3].split('.')[0]))

# Stored as `frames` so the `images` batch tensor from the training loop is left untouched.
frames = [imageio.imread(os.path.join(output_save_path, name)) for name in file_names]

gif_path = os.path.join(output_save_path, 'output_gif_{}.gif'.format(date))
if frames:
    imageio.mimsave(gif_path, frames, fps=1)
else:
    print('No sample grids found in {}; GIF not written.'.format(output_save_path))

In [None]:
from IPython.display import Image, display

# Show the progress GIF inline. `display` is imported explicitly instead of relying on
# IPython's injected builtin. Passing filename= lets IPython take the format from the
# .gif extension rather than inferring it from raw bytes.
display(Image(filename=os.path.join(output_save_path, 'output_gif_{}.gif'.format(date))))