In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [4]:
import pickle
with open('/kaggle/input/dataset/ordered_embeddings_dict.pkl', 'rb') as pickle_file:
    ordered_embeddings_dict = pickle.load(pickle_file)

  return torch.load(io.BytesIO(b))


In [5]:
import pickle
with open('/kaggle/input/dataset/ordered_image_tensor_dict (1).pkl', 'rb') as pickle_file:
    ordered_image_tensor_dict = pickle.load(pickle_file)

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GenerateC(nn.Module):
    def __init__(self):
        super(GenerateC, self).__init__()
        
    def forward(self, x):
        mean = x[:, :128]  
#         print("mean:" , mean.shape)
        log_sigma = x[:, 128:]
#         print("logsigma:", log_sigma.shape)
        stddev = torch.exp(log_sigma)
#         print("stdev shape:",stddev.shape)
        epsilon = torch.randn(mean.shape[0], mean.shape[1], device=mean.device)
#         print("epsilon shape:",epsilon.shape)
        c = stddev * epsilon + mean
        return c

class ConditionalAugmentation(nn.Module):
    def __init__(self):
        super(ConditionalAugmentation, self).__init__()
        self.fc = nn.Linear(768, 256)  # Adjusted to 768 input and 256 output
        self.lrelu = nn.LeakyReLU(0.2)
        
    def forward(self, x):
        x = self.fc(x)
        x = self.lrelu(x)
        return x

class EmbeddingCompressor(nn.Module):
    def __init__(self):
        super(EmbeddingCompressor, self).__init__()
        self.fc = nn.Linear(768, 256)  # Adjusted to 768 input and 256 output
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.fc(x)
        x = self.relu(x)
        return x

class Stage1Generator(nn.Module):
    def __init__(self):
        super(Stage1Generator, self).__init__()
        self.fc1 = nn.Linear(768, 256)  # Adjusted to 768 input and 256 output
        self.lrelu = nn.LeakyReLU(0.2)
        self.generate_c = GenerateC()
        
        self.fc2 = nn.Linear(128 + 100, 128 * 8 * 4 * 4)  # Adjusted to 256 + 100
        self.relu = nn.ReLU()
        self.reshape = nn.Unflatten(1, (128 * 8, 4, 4))
        
        self.upconv1 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128 * 8, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        
        self.upconv2 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        
        self.upconv3 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        
        self.upconv4 = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1, bias=False)
        self.tanh = nn.Tanh()
        
    def forward(self, x1, x2):
        x1= x1.squeeze(1)
        x1=self.fc1(x1)
        mean_logsigma = self.lrelu(x1)
#         print("mean_logsigma shape:", mean_logsigma.shape)
        c = self.generate_c(mean_logsigma)
#         print("c shape",c.shape)
        gen_input = torch.cat([c, x2], dim=1)
#         print("shape after concatenate:", gen_input.shape)
        
        x = self.fc2(gen_input)
        x = self.relu(x)
        x = self.reshape(x)
        
        x = self.upconv1(x)
        x = self.upconv2(x)
        x = self.upconv3(x)
        x = self.upconv4(x)
        x = self.final_conv(x)
        x = self.tanh(x)
        
        return x, mean_logsigma

class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = F.relu(out)
        return out

class JointBlock(nn.Module):
    def forward(self, c, x):
        c = c.unsqueeze(2).unsqueeze(3)
        c = c.expand(-1, -1, x.size(2), x.size(3))
        return torch.cat([c, x], dim=1)

class Stage2Generator(nn.Module):
    def __init__(self):
        super(Stage2Generator, self).__init__()
        
        # CA Augmentation Network
        self.fc = nn.Linear(768, 256)
        self.leaky_relu = nn.LeakyReLU(0.2)
        
        # Image Encoder
        self.conv1 = nn.Conv2d(3, 128, kernel_size=3,stride=1, padding=1)
        self.conv2 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)
        
        self.bn2 = nn.BatchNorm2d(256)
        self.bn3 = nn.BatchNorm2d(512)
        self.reduce_channels = nn.Conv2d(768, 512, kernel_size=1)
        
        # Residual Blocks
        self.residual_block1 = ResidualBlock(512)
        self.residual_block2 = ResidualBlock(512)
        self.residual_block3 = ResidualBlock(512)
        self.residual_block4 = ResidualBlock(512)
        
        # Upsampling Layers
        self.upsample1 = nn.Upsample(scale_factor=2)
        self.conv_up1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn_up1 = nn.BatchNorm2d(512)
        
        self.upsample2 = nn.Upsample(scale_factor=2)
        self.conv_up2 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.bn_up2 = nn.BatchNorm2d(256)
        
        self.upsample3 = nn.Upsample(scale_factor=2)
        self.conv_up3 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.bn_up3 = nn.BatchNorm2d(128)
        
        self.upsample4 = nn.Upsample(scale_factor=2)
        self.conv_up4 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn_up4 = nn.BatchNorm2d(64)
        
        # Final layer
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        self.tanh = nn.Tanh()
        
    def forward(self, z, lr_img):
        # CA Augmentation
        z = z.squeeze(1)
        c = self.fc(z)
        c = self.leaky_relu(c)
        
        # Image Encoder
        x = F.pad(lr_img, (1, 1, 1, 1))
        x = F.relu(self.conv1(x))
        
        x = F.pad(x, (1, 1, 1, 1))
        x = F.relu(self.bn2(self.conv2(x)))
        
        x = F.pad(x, (1, 1, 1, 1))
        x = F.relu(self.bn3(self.conv3(x)))
        
        # Joint block
        joint_block = JointBlock()
        c_code = joint_block(c, x)
        c_code = self.reduce_channels(c_code)
        # print(c_code.shape)
        
        # Residual blocks
        x = self.residual_block1(c_code)
        x = self.residual_block2(x)
        x = self.residual_block3(x)
        x = self.residual_block4(x)
        
        # Upsampling blocks
        x = self.upsample1(x)
        x = F.relu(self.bn_up1(self.conv_up1(x)))
        
        x = self.upsample2(x)
        x = F.relu(self.bn_up2(self.conv_up2(x)))
        
        x = self.upsample3(x)
        x = F.relu(self.bn_up3(self.conv_up3(x)))
        
        x = self.upsample4(x)
        x = F.relu(self.bn_up4(self.conv_up4(x)))
        
        # Final output
        x = self.final_conv(x)
        x = self.tanh(x)
        
        return x, c


class Stage2Discriminator(nn.Module):
    def __init__(self):
        super(Stage2Discriminator, self).__init__()
        self.fc_2 = nn.Linear(768, 128)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)
        self.conv5 = nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1)
        self.conv6 = nn.Conv2d(1024, 2048, kernel_size=4, stride=2, padding=1)
        
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        self.bn4 = nn.BatchNorm2d(512)
        self.bn5 = nn.BatchNorm2d(1024)
        self.bn6 = nn.BatchNorm2d(2048)
        
        self.conv7 = nn.Conv2d(2048, 1024, kernel_size=1, stride=1, padding=0)
        self.bn7 = nn.BatchNorm2d(1024)
        
        self.conv8 = nn.Conv2d(1024, 512, kernel_size=1, stride=1, padding=0)
        self.bn8 = nn.BatchNorm2d(512)
        
        self.conv9 = nn.Conv2d(640, 128, kernel_size=1, stride=1, padding=0)
        self.bn9 = nn.BatchNorm2d(128)
        
        self.conv10 = nn.Conv2d(128, 512, kernel_size=3, stride=1, padding=1)
        self.bn10 = nn.BatchNorm2d(512)
        
        self.conv_final = nn.Conv2d(512, 1, kernel_size=4, stride=1)
        
    def forward(self, img, embd):
        
        x = F.leaky_relu(self.conv1(img), 0.2)
        x = F.leaky_relu(self.bn2(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.bn3(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.bn4(self.conv4(x)), 0.2)
        x = F.leaky_relu(self.bn5(self.conv5(x)), 0.2)
        x = F.leaky_relu(self.bn6(self.conv6(x)), 0.2)
        
        x = F.leaky_relu(self.bn7(self.conv7(x)), 0.2)
        x = F.leaky_relu(self.bn8(self.conv8(x)), 0.2)
        
        # Joint with c_code
        # c_code= c_code.squeeze(1)
        c_code = self.fc_2(embd)
        c_code= c_code.squeeze(1)
        c_code = c_code.unsqueeze(2).unsqueeze(3)
        c_code = c_code.expand(c_code.size(0), c_code.size(1), x.size(2), x.size(3))
        x = torch.cat([x, c_code], dim=1)
        
        x = F.leaky_relu(self.bn9(self.conv9(x)), 0.2)
        x = F.leaky_relu(self.bn10(self.conv10(x)), 0.2)
        
        validity = self.conv_final(x)
        return torch.sigmoid(validity)

class AdversarialModel(nn.Module):
    def __init__(self, gen_model1, gen_model2, dis_model):
        """
        Initialize adversarial model.
        """
        super(AdversarialModel, self).__init__()
        self.gen_model1 = gen_model1
        self.gen_model2 = gen_model2
        self.dis_model = dis_model
        
        # Freeze gen_model1 and dis_model parameters (set them to eval mode)
        self.gen_model1.eval()
        self.dis_model.eval()
        

    def forward(self, embeddings_input, noise_input):

        # Pass through gen_model1 to get low-resolution images and mean_logsigma1
        with torch.no_grad():  # Freeze the first generator
            lr_images, mean_logsigma1 = self.gen_model1(embeddings_input, noise_input)
        
        # Pass low-resolution images through gen_model2 to get high-resolution images and mean_logsigma2
        hr_images, mean_logsigma2 = self.gen_model2(embeddings_input, lr_images)
        
        # Pass high-resolution images and compressed embeddings through the discriminator
        valid = self.dis_model(hr_images, embeddings_input)
        
        return valid, mean_logsigma2


In [7]:
def KL_loss(y_pred):
    # Extract mean and log_sigma from y_pred
    mean = y_pred[:, :128]
    log_sigma = y_pred[:, 128:]
    loss = -log_sigma + 0.5 * (-1 + torch.exp(2. * log_sigma) + mean**2)

    loss = loss.mean()
    
    return loss

In [8]:
import matplotlib.pyplot as plt
import numpy as np

def save_rgb_img(img, path):

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    if isinstance(img, torch.Tensor):
        img = img.detach().permute(1, 2, 0).cpu().numpy()  
    
    ax.imshow(img)
    ax.axis("off")
    ax.set_title("Image")

    plt.savefig(path)
    plt.close()
    


In [9]:
from torch.utils.tensorboard import SummaryWriter

def write_log(writer, name, loss, global_step):

    writer.add_scalar(name, loss, global_step)

In [10]:
def plot_generated_images(images, epoch, n_images=8):

    fig, axes = plt.subplots(1, n_images, figsize=(n_images * 2, 2))  
    for i in range(n_images):
        ax = axes[i]
        img = images[i].detach().permute(1, 2, 0).cpu().numpy()  
        ax.imshow(img)
        ax.axis("off")  
    plt.suptitle(f'Generated Images at Epoch {epoch}')  
    plt.show()  # Display the plot

In [11]:
import torch
import time
import gc
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo

def clear_gpu_memory():
    torch.cuda.empty_cache()
    gc.collect()
    del variables

def wait_until_enough_gpu_memory(min_memory_available, max_retries=10, sleep_time=5):
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(torch.cuda.current_device())

    for _ in range(max_retries):
        info = nvmlDeviceGetMemoryInfo(handle)
        if info.free >= min_memory_available:
            break
        print(f"Waiting for {min_memory_available} bytes of free GPU memory. Retrying in {sleep_time} seconds...")
        time.sleep(sleep_time)
    else:
        raise RuntimeError(f"Failed to acquire {min_memory_available} bytes of free GPU memory after {max_retries} retries.")

In [13]:
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
import os
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import time

# Hyperparameters and setup
batch_size = 8
z_dim = 100
stage1_generator_lr = 0.0002
stage1_discriminator_lr = 0.0002
epochs = 250
condition_dim = 128
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create a tensorboard writer
writer = SummaryWriter(log_dir="logs/".format(time.time()))

# Optimizers


# Loss functions
bce_loss = nn.BCELoss().to(device)

stage2_gen = Stage2Generator().to(device)
stage2_dis = Stage2Discriminator().to(device)
embedding_compressor_model = EmbeddingCompressor().to(device)
stage1_gen = Stage1Generator().to(device)
stage1_gen.load_state_dict(torch.load("/kaggle/input/weight-2/generator_epoch_40.pth"))  # Load the weights
stage1_gen.eval()  # Set to evaluation mode
adversarial_model = AdversarialModel(stage1_gen,stage2_gen, stage2_dis).to(device)

dis_optimizer = optim.Adam(stage2_dis.parameters(), lr=stage1_discriminator_lr, betas=(0.5, 0.999))
gen_optimizer = optim.Adam(stage2_gen.parameters(), lr=stage1_generator_lr, betas=(0.5, 0.999))

# Labels
real_labels = torch.full((batch_size, 1), 0.9, dtype=torch.float, device=device)
fake_labels = torch.full((batch_size, 1), 0.1, dtype=torch.float, device=device)

# Data preparation (assume dictionaries are in order)
image_keys = list(ordered_image_tensor_dict.keys())

# Training loop
for epoch in tqdm(range(epochs)):
    print(f"Epoch {epoch + 1}/{epochs}")

    gen_losses = []
    dis_losses = []

    num_batches = len(ordered_image_tensor_dict) // batch_size
    for index in tqdm(range(num_batches)):
        # Select batch
        batch_keys = image_keys[index * batch_size:(index + 1) * batch_size]
        real_images = torch.stack([ordered_image_tensor_dict[key] for key in batch_keys]).to(device)
        embeddings = torch.stack([ordered_embeddings_dict[key] for key in batch_keys]).to(device)

        # Normalize real images
        # real_images = (real_images - 127.5) / 127.5

        # Generate noise and fake images
        z_noise = torch.randn(batch_size, z_dim, device=device)
        lr_fake_images, _ = stage1_gen(embeddings, z_noise)
        hr_fake_images, _ = stage2_gen(embeddings, lr_fake_images)

        # Train the discriminator
        stage2_dis.zero_grad()

        # Train with real images
        dis_real_output = stage2_dis(real_images, embeddings).view(-1)  # Flatten the output
        dis_loss_real = bce_loss(dis_real_output, real_labels.view(-1))  # Flatten the labels

# Train with fake images
        dis_fake_output = stage2_dis(hr_fake_images.detach(), embeddings).view(-1)  # Flatten the output
        dis_loss_fake = bce_loss(dis_fake_output, fake_labels.view(-1))

        wrong_output = stage2_dis(real_images[:(batch_size - 1)], embeddings[1:]).view(-1)  # Flatten the output
        wrong_loss = bce_loss(wrong_output, fake_labels[1:].view(-1).to(device))  
        
        # Combine losses
        d_loss = 0.5 * (dis_loss_real + 0.5 * (wrong_loss + dis_loss_fake))
        d_loss.backward()
        dis_optimizer.step()

        # Train the generator
        stage2_gen.zero_grad()

        # Adversarial loss and KL loss
        valid, mean_logsigma2 = adversarial_model(embeddings, z_noise)
        g_loss_adv = bce_loss(valid.squeeze(), real_labels.view(-1))
        g_loss_kl = KL_loss(mean_logsigma2)
        g_loss = g_loss_adv + 2 * g_loss_kl

        g_loss.backward()
        gen_optimizer.step()

        # Record losses
        gen_losses.append(g_loss.item())
        dis_losses.append(d_loss.item())
    
        if index % 50 == 0:
            print(f'Epoch [{epoch}/{epochs}] Step [{index}] Discriminator Loss: {d_loss.item()} Generator Loss: {g_loss.item()}')

    writer.add_scalar('Discriminator Loss', np.mean(dis_losses), epoch)
    writer.add_scalar('Generator Loss', np.mean(gen_losses), epoch)
            
    if epoch % 10 == 0:
        torch.save(stage2_gen.state_dict(), f'generator_epoch_{epoch}.pth')
        torch.save(stage2_dis.state_dict(), f'discriminator_epoch_{epoch}.pth')
    # Print average losses per epoch
    if epoch % 10 == 0:
        img_path = os.path.join(output_dir, f'generated_img_epoch_{epoch}.png')
        save_rgb_img(hr_fake_images[0], img_path)  # Save the first image in the batch for visualization
        
        # Plot generated images
        plot_generated_images(fake_images, epoch)  
    # Save generated images
    if epoch % 2 == 0:
        with torch.no_grad():
            z_noise2 = torch.randn(batch_size, z_dim, device=device)
            lr_fake_images, _ = stage1_gen(embeddings, z_noise2)
            hr_fake_images, _ = stage2_gen(embeddings, lr_fake_images)
            save_images(hr_fake_images, epoch, batch_size)

# Save the models after training
torch.save(stage2_gen.state_dict(), "stage2_gen.pth")
torch.save(stage2_dis.state_dict(), "stage2_dis.pth")


  stage1_gen.load_state_dict(torch.load("/kaggle/input/weight-2/generator_epoch_40.pth"))  # Load the weights
  0%|          | 0/250 [00:00<?, ?it/s]

Epoch 1/250



  0%|          | 0/102 [00:00<?, ?it/s][A
  0%|          | 0/250 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.27 GiB. GPU 0 has a total capacity of 14.74 GiB of which 520.12 MiB is free. Process 2412 has 14.23 GiB memory in use. Of the allocated memory 12.36 GiB is allocated by PyTorch, and 1.73 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)