In [None]:
import zipfile
import os

# kishresun2016@gmail.com -> /content/drive/MyDrive/Sem6/CIP_Team6_2025/Transformation_zip.zip
# kishreigns@gamil.com -> /content/drive/MyDrive/Transformation_zip.zip
# malarvannanm11@gmail.com -> /content/drive/MyDrive/Transformation_zip.zip
zip_path = "/content/drive/MyDrive/Transformation_zip.zip"  # Change to your uploaded zip file name
extract_path = "/content/transformation"
drive_checkpoint_link = "/content/drive/MyDrive/checkpoints/deepkeygen_checkpoint.pth"

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

print(f"Extracted dataset at: {extract_path}")

FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/Transformation_zip.zip'

In [14]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [12]:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms, datasets, utils
import os
import kagglehub
import torch.autograd as autograd
import matplotlib.pyplot as plt
import pandas as pd

import torch.nn.functional as F
import csv

data = []
drive_checkpoint_link = "/content/drive/MyDrive/checkpoints/deepkeygen_checkpoint.pth"
# /checkpoints/deepkeygen_checkpoint.pth

# Function to download multiple datasets
def download_datasets(dataset_list):
    dataset_dirs = [kagglehub.dataset_download(dataset) for dataset in dataset_list]
    return dataset_dirs

# Function to load multiple datasets into a single DataLoader
def load_multiple_datasets(data_dirs, transform, batch_size):
    datasets_list = [datasets.ImageFolder(data_dir, transform=transform) for data_dir in data_dirs]
    combined_dataset = ConcatDataset(datasets_list)
    return DataLoader(combined_dataset, batch_size=batch_size, shuffle=True)

def save_checkpoint(generator, critic, optimizer_g, optimizer_d, epoch, filepath=drive_checkpoint_link):
    checkpoint = {
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'critic_state_dict': critic.state_dict(),
        'optimizer_g_state_dict': optimizer_g.state_dict(),
        'optimizer_d_state_dict': optimizer_d.state_dict()
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved at epoch {epoch+1}")

def load_checkpoint(generator, critic, optimizer_g, optimizer_d, filepath=drive_checkpoint_link, device=None):
    if os.path.exists(filepath) and os.path.getsize(filepath) > 0: # Check if file exists and has content
        try:
            checkpoint = torch.load(filepath, map_location=device)
            generator.load_state_dict(checkpoint['generator_state_dict'])
            critic.load_state_dict(checkpoint['critic_state_dict'])
            optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
            optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])
            start_epoch = checkpoint['epoch'] + 1
            print(f"Resuming training from epoch {start_epoch}")
        except RuntimeError as e:
            print(f"Error loading checkpoint: {e}") # Print error message if loading fails
            start_epoch = 0 # Start from epoch 0 if loading fails
            print("Starting training from scratch due to checkpoint loading error.")
    else:
        start_epoch = 0
        print("No checkpoint found, starting training from scratch.")

    return start_epoch


# Generator network
# Generator network (modified to output 256x256)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),  # (512, 4, 4)
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),  # (256, 8, 8)
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),  # (128, 16, 16)
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),  # (64, 32, 32)
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),  # (32, 64, 64)
            nn.BatchNorm2d(32),
            nn.ReLU(True),

            nn.ConvTranspose2d(32, 16, 4, 2, 1, bias=False),  # (16, 128, 128)
            nn.BatchNorm2d(16),
            nn.ReLU(True),

            nn.ConvTranspose2d(16, 3, 4, 2, 1, bias=False),  # (3, 256, 256)
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)


# Critic network
class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 1, 4, 1, 0, bias=False)
        )

    def forward(self, img):
        return self.model(img).view(-1)

# Compute gradient penalty
def compute_gradient_penalty(critic, real_samples, fake_samples, device):
    batch_size = real_samples.size(0)

    # Ensure alpha has the same shape as real_samples
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    alpha = alpha.expand(real_samples.shape)

    #print(f"alpha shape: {alpha.shape}")
    #print(f"real_samples shape: {real_samples.shape}")
    #print(f"fake_samples shape: {fake_samples.shape}")



    fake_samples = F.interpolate(fake_samples, size=(256, 256), mode='bilinear', align_corners=False)


    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    critic_interpolates = critic(interpolates)
    grad_outputs = torch.ones_like(critic_interpolates, device=device)

    gradients = torch.autograd.grad(
        outputs=critic_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]

    gradients = gradients.view(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    return gradient_penalty


# Save generated images
#def save_generated_images(generator, epoch, device, num_images=8):
 #   generator.eval()
  #  with torch.no_grad():
   #     z = torch.randn(num_images, 100, 1, 1, device=device)
    #    fake_images = generator(z)

        # Ensure images are exactly 256x256
     #   fake_images = F.interpolate(fake_images, size=(256, 256), mode='bilinear', align_corners=False)

      #  fake_images = (fake_images + 1) / 2  # Normalize from [-1,1] to [0,1]
       # os.makedirs("generated_images", exist_ok=True)
        #image_path = f"generated_images/epoch_{epoch}.png"
        #utils.save_image(fake_images, image_path, normalize=True, nrow=4)
        #print(f"Saved generated images at {image_path}")
    #generator.train()


def save_generated_images(generator, epoch, device):
    generator.eval()
    with torch.no_grad():
        z = torch.randn(1, 100, 1, 1, device=device)
        fake_img = generator(z).squeeze(0)  # Shape: (3, 256, 256)
        os.makedirs("generated_images", exist_ok=True)
        image_path = f"generated_images/epoch_{epoch}.png"
        # Save image with correct size
        utils.save_image(fake_img, image_path, normalize=True)

    generator.train()



# Checkpoint and loss validation in the training loop:
def train_deepkeygen(generator, critic, source_loader, transform_loader, num_epochs, lr, device):
    optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
    optimizer_d = optim.Adam(critic.parameters(), lr=lr, betas=(0.5, 0.9))
    lambda_gp = 10
    critic_iterations = 5

    os.makedirs("checkpoints", exist_ok=True)
    start_epoch = load_checkpoint(generator, critic, optimizer_g, optimizer_d)

    csv_path = "metrics.csv"
    if not os.path.exists(csv_path):
        with open(csv_path, "w", newline="") as file:
            writer = csv.writer(file)
            writer.writerow(["Epoch", "Generator Loss", "Critic Loss", "Wasserstein Distance", "Gradient Penalty", "D_real", "D_fake"])

    for epoch in range(start_epoch, num_epochs):
        print(f"Epoch : {epoch+1}/{num_epochs}", flush=True)

        for (source_imgs, _), (transform_imgs, _) in zip(source_loader, transform_loader):

            min_batch_size = min(source_imgs.size(0), transform_imgs.size(0))
            source_imgs = source_imgs[:min_batch_size].to(device)
            transform_imgs = transform_imgs[:min_batch_size].to(device)

            for _ in range(critic_iterations):
                with torch.no_grad():  # Prevent gradient tracking
                    z = torch.randn(min_batch_size, 100, 1, 1, device=device)
                    fake_imgs = generator(z).detach()  # No gradient needed here

                real_loss = critic(transform_imgs).mean()
                fake_loss = critic(fake_imgs).mean()
                gp = compute_gradient_penalty(critic, transform_imgs, fake_imgs, device)
                critic_loss = fake_loss - real_loss + lambda_gp * gp

                optimizer_d.zero_grad()
                if torch.isnan(critic_loss) or torch.isinf(critic_loss):
                    continue
                critic_loss.backward()
                xm.optimizer_step(optimizer_d)
                xm.mark_step()

            z = torch.randn(min_batch_size, 100, 1, 1, device=device)
            fake_imgs = generator(z)
            generator_loss = -critic(fake_imgs).mean()

            optimizer_g.zero_grad()
            if torch.isnan(generator_loss) or torch.isinf(generator_loss):
                continue
            generator_loss.backward()
            xm.optimizer_step(optimizer_g)
            xm.mark_step()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss D: {critic_loss.item()}, Loss G: {generator_loss.item()}", flush=True)

        wasserstein_distance = real_loss.item() - fake_loss.item()

        with open(csv_path, "a", newline="") as file:
            writer = csv.writer(file)
            writer.writerow([epoch+1, generator_loss.item(), critic_loss.item(), wasserstein_distance, gp.item(), real_loss.item(), fake_loss.item()])

        save_generated_images(generator, epoch + 1, device)
        if (epoch + 1) % 5 == 0:
            save_checkpoint(generator, critic, optimizer_g, optimizer_d, epoch)

    print("[+] Training ended", flush=True)

# Main script
if __name__ == "__main__":
    print(f"[+] Current working directory: {os.getcwd()}")

    device = xm.xla_device()  # Use TPU device
    #device = torch.device("cuda")
    print(f"[+] Using device: {device}")

    #csvpath = "/content/loss.csv"

    source_datasets = ["raddar/tuberculosis-chest-xrays-montgomery", "masoudnickparvar/brain-tumor-mri-dataset"]
    source_data_dirs = download_datasets(source_datasets)


    print("[+] Datasets downloaded successfully")

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    batch_size = 32
    num_epochs = 100
    lr = 0.0002
    extract_path = "/content/transformation"

    source_loader = load_multiple_datasets(source_data_dirs, transform, batch_size)
    transform_loader = DataLoader(datasets.ImageFolder(extract_path, transform=transform), batch_size=batch_size, shuffle=True)

    print("[+] Datasets loaded successfully")

    generator = Generator().to(device)
    critic = Critic().to(device)

    print("[+] Training begins")
    train_deepkeygen(generator, critic, source_loader, transform_loader, num_epochs, lr, device)
    print("[+] Training ended")




[+] Current working directory: /content
[+] Using device: xla:0
Downloading from https://www.kaggle.com/api/v1/datasets/download/raddar/tuberculosis-chest-xrays-montgomery?dataset_version_number=1...


100%|██████████| 585M/585M [00:03<00:00, 190MB/s]

Extracting files...





Downloading from https://www.kaggle.com/api/v1/datasets/download/masoudnickparvar/brain-tumor-mri-dataset?dataset_version_number=1...


100%|██████████| 149M/149M [00:00<00:00, 264MB/s]

Extracting files...





KeyboardInterrupt: 

# New Section

In [None]:
from google.colab import files

files.download("/content/checkpoints/deepkeygen_checkpoint.pth")


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
!pip uninstall -y torch torch-xla
!pip install torch==2.0.1 --index-url https://download.pytorch.org/whl/cu118
!pip install cloud-tpu-client==0.10 torch_xla==2.0.1 --index-url https://download.pytorch.org/whl/cu118

In [None]:
!pip uninstall -y torch torchvision
!pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu118

Found existing installation: torch 2.5.1+cpu
Uninstalling torch-2.5.1+cpu:
  Successfully uninstalled torch-2.5.1+cpu
Found existing installation: torchvision 0.20.1+cpu
Uninstalling torchvision-0.20.1+cpu:
  Successfully uninstalled torchvision-0.20.1+cpu
Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting torch==2.0.1
  Downloading https://download.pytorch.org/whl/cu118/torch-2.0.1%2Bcu118-cp311-cp311-linux_x86_64.whl (2267.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 GB[0m [31m?[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.15.2
  Downloading https://download.pytorch.org/whl/cu118/torchvision-0.15.2%2Bcu118-cp311-cp311-linux_x86_64.whl (6.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.1/6.1 MB[0m [31m93.3 MB/s[0m eta [36m0:00:00[0m
Collecting triton==2.0.0 (from torch==2.0.1)
  Downloading https://download.pytorch.org/whl/triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64

In [None]:
!pip install --upgrade torch torchvision torchaudio


Collecting torchvision
  Downloading torchvision-0.21.0-cp311-cp311-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio
  Downloading torchaudio-2.6.0-cp311-cp311-manylinux1_x86_64.whl.metadata (6.6 kB)
Downloading torchvision-0.21.0-cp311-cp311-manylinux1_x86_64.whl (7.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m40.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading torchaudio-2.6.0-cp311-cp311-manylinux1_x86_64.whl (3.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m53.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchvision, torchaudio
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.20.1+cu124
    Uninstalling torchvision-0.20.1+cu124:
      Successfully uninstalled torchvision-0.20.1+cu124
  Attempting uninstall: torchaudio
    Found existing installation: torchaudio 2.5.1+cu124
    Uninstalling torchaudio-2.5.1+cu124:
      Successfu

In [None]:
!pip uninstall torch torch-xla
!pip install torch
!pip install torch-xla

In [None]:
import torch
import torch.nn as nn

device = xm.xla_device() if 'xm' in globals() else torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the Generator model with the same structure as the one used to save the checkpoint
import torch

device = xm.xla_device() if 'xm' in globals() else torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the Generator model with the same structure as the one used to save the checkpoint
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential( # Change 'main' to 'model'
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False), # Adjust output channels to 32 to match the checkpoint
            nn.BatchNorm2d(32), # Add BatchNorm2d layer
            nn.ReLU(True), # Add ReLU layer
            nn.ConvTranspose2d(32, 16, 4, 2, 1, bias=False),  # (16, 128, 128)
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 3, 4, 2, 1, bias=False),  # (3, 256, 256)
            nn.Tanh()
        )

    def forward(self, input):
        return self.model(input) # Change 'main' to 'model'





generator = Generator().to(device)


checkpoint_path = "/content/drive/MyDrive/Sem6/CIP_Team6_2025/Final_Training/deepkeygen_checkpoint__30-35__epoch.pth"
checkpoint = torch.load(checkpoint_path, map_location=device)
generator.load_state_dict(checkpoint["generator_state_dict"])
generator.eval()  # Switch to evaluation mode

print("Model loaded successfully.")

AttributeError: Can't get attribute '_rebuild_device_tensor_from_cpu_tensor' on <module 'torch._utils' from '/usr/local/lib/python3.11/dist-packages/torch/_utils.py'>

In [None]:
import torch
import cv2
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

# Load the medical image
image_path = "/content/drive/MyDrive/Sem6/CIP_Team6_2025/Medical_Images/MCUCXR_0001_0.png"
image = Image.open(image_path).convert("L")

# Define image transformations
transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Apply transformations
image_tensor = transform(image).unsqueeze(0).to(device)  # Add batch dimension

z = torch.randn(1, 100, 1, 1, device=device)

# Reshape the noise vector to have the same spatial dimensions as the image tensor
z = z.repeat(1, 1, image_tensor.shape[2], image_tensor.shape[3]) # Repeat z along spatial dimensions

# Instead of concatenating, only use the noise vector 'z' as input to the generator:
# with torch.no_grad():
#     fake_image = generator(z)

# Pass only the noise vector (z) to the generator
with torch.no_grad():
    fake_image = generator(z)


os.makedirs("generated_images", exist_ok=True)
image_path = "generated_images/generated_sample.png"
# Make sure utils is imported from torchvision
from torchvision import utils
utils.save_image(fake_image, image_path, normalize=True)

print(f"Generated image saved at: {image_path}")

KeyboardInterrupt: 

In [None]:
!pip uninstall -y torch torchvision
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

Found existing installation: torch 2.6.0
Uninstalling torch-2.6.0:
  Successfully uninstalled torch-2.6.0
Found existing installation: torchvision 0.21.0
Uninstalling torchvision-0.21.0:
  Successfully uninstalled torchvision-0.21.0
[0mLooking in indexes: https://download.pytorch.org/whl/cu118
Collecting torch
  Downloading https://download.pytorch.org/whl/cu118/torch-2.6.0%2Bcu118-cp311-cp311-linux_x86_64.whl.metadata (27 kB)
Collecting torchvision
  Downloading https://download.pytorch.org/whl/cu118/torchvision-0.21.0%2Bcu118-cp311-cp311-linux_x86_64.whl.metadata (6.1 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.8.89 (from torch)
  Downloading https://download.pytorch.org/whl/cu118/nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-manylinux1_x86_64.whl (23.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.2/23.2 MB[0m [31m61.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu11==11.8.89 (from torch)
  Downloading https://download.pytorch.org/whl/cu118