In [1]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt


# Directory holding the image files; each 'filename' value from the CSV is
# joined onto this path when a sample is loaded.
IMAGE_FOLDER = '/projectnb/ec523kb/projects/teams_Fall_2024/Team_11/Adwait/Work_on_this_code/images/images_normalized/'

# CSV index of the images; the dataset class reads its 'uid', 'projection'
# and 'filename' columns.
CSV_FILE = '/projectnb/ec523kb/projects/teams_Fall_2024/Team_11/Adwait/Work_on_this_code/indiana_projections.csv'

# Hyperparameters
BATCH_SIZE = 128  # frontal/lateral pairs per DataLoader batch
EPOCHS = 50  # number of passes over the dataloader
LATENT_DIM = 1024  # length of the noise vector fed to the generator
LEARNING_RATE = 0.0002  # Adam learning rate used for both optimizers
# Side length (pixels) every image is resized to by the transform pipeline.
# NOTE(review): the discriminator resamples its image inputs to a 64^3 grid,
# so a value this large mostly adds loading/resizing cost -- confirm intent.
IMAGE_SIZE = 1440

# Per-batch loss histories; appended to inside the training loop and plotted
# once training finishes.
d_real_losses = []
d_fake_losses = []
g_losses = []

# Define device (cuda if available, otherwise cpu)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset class remains the same
# Define dataset class
class XRayDataset(Dataset):
    """Dataset of (frontal, lateral) X-ray image pairs grouped by ``uid``.

    The CSV must provide the columns ``uid``, ``projection`` (with the values
    ``'Frontal'`` / ``'Lateral'``) and ``filename``.  Only uids that have at
    least one row of each projection are exposed as samples.

    Args:
        csv_file: Path (or file-like object) handed to ``pandas.read_csv``.
        img_folder: Directory that the CSV filenames are relative to.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, csv_file, img_folder, transform=None):
        self.csv = pd.read_csv(csv_file)
        self.img_folder = img_folder
        self.transform = transform
        # One group of CSV rows per uid.
        self.pairs = self.csv.groupby('uid')
        # uids usable as samples, in group iteration order.
        self.valid_uids = self.check_integrity()

    def check_integrity(self):
        """Return the uids that have both a Frontal and a Lateral row."""
        return [
            study_uid
            for study_uid, rows in self.pairs
            if (rows['projection'] == 'Frontal').any()
            and (rows['projection'] == 'Lateral').any()
        ]

    def __len__(self):
        """Number of uids that passed the integrity check."""
        return len(self.valid_uids)

    def __getitem__(self, idx):
        """Load the first frontal and first lateral image of sample ``idx``.

        Returns:
            ``(frontal, lateral)``: RGB PIL images, or whatever ``transform``
            returns for them when a transform was supplied.
        """
        rows = self.pairs.get_group(self.valid_uids[idx])

        def first_path(view):
            # When a uid has several rows for one projection, the first wins.
            name = rows.loc[rows['projection'] == view, 'filename'].values[0]
            return os.path.join(self.img_folder, name)

        # Find the paths for the frontal and lateral images
        frontal_path = first_path('Frontal')
        lateral_path = first_path('Lateral')

        # Load images
        frontal = Image.open(frontal_path).convert("RGB")
        lateral = Image.open(lateral_path).convert("RGB")

        if not self.transform:
            return frontal, lateral
        frontal = self.transform(frontal)
        lateral = self.transform(lateral)
        return frontal, lateral

# Define transforms for image preprocessing
# Preprocessing applied to every loaded image, in order:
#   1. resize to IMAGE_SIZE x IMAGE_SIZE,
#   2. convert the PIL image to a float tensor,
#   3. normalize with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5.
# NOTE(review): the images are loaded as RGB (3 channels) while a single
# mean/std value is given; this relies on Normalize broadcasting the value
# across channels -- confirm against the installed torchvision version.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Build the paired-view dataset and a shuffled loader over it. Each batch is
# a (frontal, lateral) tuple; the final batch may be smaller than BATCH_SIZE.
dataset = XRayDataset(csv_file=CSV_FILE, img_folder=IMAGE_FOLDER, transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

class Generator3D(nn.Module):
    """Transposed-convolution generator mapping a latent vector to a volume.

    The latent vector is reshaped to ``(N, latent_dim, 1, 1, 1)`` and passed
    through five ``ConvTranspose3d`` stages whose spatial sizes are
    1 -> 4 -> 8 -> 16 -> 32 -> 64.  The first four stages are followed by
    batch norm and ReLU; the last one by ``tanh``.

    Args:
        latent_dim: Number of channels of the reshaped latent input.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        self.conv1 = nn.ConvTranspose3d(
            latent_dim, 512, kernel_size=4, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm3d(512)
        self.conv2 = nn.ConvTranspose3d(
            512, 256, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(256)
        self.conv3 = nn.ConvTranspose3d(
            256, 128, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm3d(128)
        self.conv4 = nn.ConvTranspose3d(
            128, 64, kernel_size=4, stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm3d(64)
        self.conv5 = nn.ConvTranspose3d(
            64, 1, kernel_size=4, stride=2, padding=1, bias=False)

    def forward(self, z, frontal_img, lateral_img):
        """Generate a ``(N, 1, 64, 64, 64)`` volume with values in [-1, 1].

        Args:
            z: Latent tensor that can be viewed as ``(N, latent_dim, 1, 1, 1)``.
            frontal_img: Accepted for interface symmetry with the
                discriminator; not used by this method.
            lateral_img: Accepted for interface symmetry with the
                discriminator; not used by this method.
        """
        hidden = z.view(-1, self.latent_dim, 1, 1, 1)
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for deconv, norm in stages:
            hidden = torch.relu(norm(deconv(hidden)))
        hidden = torch.tanh(self.conv5(hidden))
        # Ensure output size is 64x64x64
        return nn.functional.interpolate(
            hidden, size=(64, 64, 64), mode='trilinear', align_corners=False)


# --- 3D Discriminator ---
class Discriminator3D(nn.Module):
    """3D convolutional discriminator conditioned on two 2D views.

    The volume (1 channel) is concatenated with the frontal and lateral
    images (3 channels each), each broadcast along one spatial axis of the
    volume grid, giving the 7 input channels of ``conv1``.  Four stride-2
    convolutions reduce the 64^3 grid to 4^3 and ``conv5`` collapses it to a
    single sigmoid score per sample.
    """

    # Edge length of the cubic grid the convolution stack is sized for.
    VOLUME_SIZE = 64

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv3d(7, 64, 4, 2, 1, bias=False)
        self.conv2 = nn.Conv3d(64, 128, 4, 2, 1, bias=False)
        self.bn2 = nn.BatchNorm3d(128)
        self.conv3 = nn.Conv3d(128, 256, 4, 2, 1, bias=False)
        self.bn3 = nn.BatchNorm3d(256)
        self.conv4 = nn.Conv3d(256, 512, 4, 2, 1, bias=False)
        self.bn4 = nn.BatchNorm3d(512)
        self.conv5 = nn.Conv3d(512, 1, 4, 1, 0, bias=False)

    def forward(self, volume, frontal_img, lateral_img):
        """Score ``volume`` given its two projection images.

        Args:
            volume: ``(N, 1, D, H, W)`` tensor; resampled to 64^3.
            frontal_img: ``(N, 3, H, W)`` image tensor, repeated along the
                first spatial axis of the volume grid.
            lateral_img: ``(N, 3, H, W)`` image tensor, repeated along the
                second spatial axis of the volume grid.

        Returns:
            ``(N, 1)`` tensor of sigmoid scores.
        """
        side = self.VOLUME_SIZE
        interpolate = nn.functional.interpolate

        volume = interpolate(
            volume, size=(side, side, side), mode='trilinear', align_corners=False)

        # Resize the images in 2D *before* broadcasting them into the volume
        # grid.  Expanding first and resampling the 5D tensor afterwards
        # yields the same values (the expanded axis is constant) but forces
        # interpolation over a tensor 64x larger than the full-resolution
        # image batch.
        frontal_2d = interpolate(
            frontal_img, size=(side, side), mode='bilinear', align_corners=False)
        lateral_2d = interpolate(
            lateral_img, size=(side, side), mode='bilinear', align_corners=False)
        frontal_3d = frontal_2d.unsqueeze(2).expand(-1, -1, side, -1, -1)
        lateral_3d = lateral_2d.unsqueeze(3).expand(-1, -1, -1, side, -1)

        leaky_relu = nn.functional.leaky_relu
        x = torch.cat([volume, frontal_3d, lateral_3d], dim=1)
        x = leaky_relu(self.conv1(x), 0.2)
        x = leaky_relu(self.bn2(self.conv2(x)), 0.2)
        x = leaky_relu(self.bn3(self.conv3(x)), 0.2)
        x = leaky_relu(self.bn4(self.conv4(x)), 0.2)
        x = torch.sigmoid(self.conv5(x))
        return x.view(-1, 1)


# Initialize models and move them to the GPU
# Instantiate both networks and move their parameters to the selected device.
generator = Generator3D(latent_dim=LATENT_DIM).to(device)
discriminator = Discriminator3D().to(device)

# Multi-GPU wrapping, currently disabled. Enabling it prefixes state_dict
# keys with "module.", which affects the checkpoints saved after training.
# generator = nn.DataParallel(generator)
# discriminator = nn.DataParallel(discriminator)

# One Adam optimizer per network, sharing the learning rate and
# betas=(0.5, 0.999).
optimizer_g = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

# Binary cross-entropy on probabilities; the discriminator already applies a
# sigmoid to its output, so no logits variant is used here.
adversarial_loss = nn.BCELoss()

# Training loop
#
# Each batch performs one discriminator update followed by one generator
# update, and appends the three scalar losses to the module-level histories.
#
# NOTE(review): no real 3D volumes are loaded anywhere in this script, so the
# "real" branch of the discriminator update scores the *same* detached fake
# volume as the "fake" branch, only with the opposite target.  With identical
# inputs the summed loss cannot fall below 2*ln(2) ~= 1.386 and the
# discriminator receives no signal separating real from generated volumes.
# Replace the placeholder with real volume data before trusting the results.
for epoch in range(EPOCHS):
    generator.train()
    discriminator.train()

    # Initialize progress bar
    with tqdm(dataloader, unit="batch") as tepoch:
        tepoch.set_description(f"Epoch {epoch + 1}/{EPOCHS}")
        for frontal_img, lateral_img in tepoch:
            # Move images to GPU
            frontal_img, lateral_img = frontal_img.to(device), lateral_img.to(device)
            n_samples = frontal_img.size(0)  # the last batch may be smaller

            # Prepare real and fake labels, allocated directly on the target
            # device to avoid a host allocation plus copy every batch.
            real_label = torch.ones(n_samples, 1, device=device)
            fake_label = torch.zeros(n_samples, 1, device=device)

            # Train Discriminator
            optimizer_d.zero_grad()

            # Generate fake 3D volume
            z = torch.randn(n_samples, LATENT_DIM, device=device)
            fake_volume = generator(z, frontal_img, lateral_img)

            # Real images (using fake_volume as placeholder for real 3D volume)
            real_pred = discriminator(fake_volume.detach(), frontal_img, lateral_img)
            d_real_loss = adversarial_loss(real_pred, real_label)

            # Fake images; detached so this step does not backpropagate into
            # the generator.
            fake_pred = discriminator(fake_volume.detach(), frontal_img, lateral_img)
            d_fake_loss = adversarial_loss(fake_pred, fake_label)

            d_loss = d_real_loss + d_fake_loss
            d_loss.backward()
            optimizer_d.step()
            d_real_losses.append(d_real_loss.item())
            d_fake_losses.append(d_fake_loss.item())

            # Train Generator: score the non-detached volume against the
            # "real" target so gradients flow back into the generator.
            optimizer_g.zero_grad()
            fake_pred = discriminator(fake_volume, frontal_img, lateral_img)
            g_loss = adversarial_loss(fake_pred, real_label)
            g_loss.backward()
            optimizer_g.step()

            g_losses.append(g_loss.item())

            tepoch.set_postfix(D_loss=d_loss.item(),D_Real_Loss = d_real_loss.item(),D_Fake_Loss = d_fake_loss.item(), G_loss=g_loss.item())

print("Training finished.")

# Persist only the learned weights (state_dict) of each network; the file
# names encode batch size, image size and epoch count.
torch.save(generator.state_dict(), f"generator3DNONLINEAR_{BATCH_SIZE}_{IMAGE_SIZE}_epoch_{EPOCHS}.pth")
torch.save(discriminator.state_dict(), f"discriminator3DNONLINEAR_{BATCH_SIZE}_{IMAGE_SIZE}_epoch_{EPOCHS}.pth")

print(f"Models saved at epoch {EPOCHS} with batch size {BATCH_SIZE} and image size {IMAGE_SIZE}")


# Draw the three per-batch loss histories on one figure and write it to disk.
plt.figure(figsize=(10, 6))
for loss_history, legend_text, line_color in (
    (d_real_losses, 'Discriminator Real Loss', 'r'),
    (d_fake_losses, 'Discriminator Fake Loss', 'orange'),
    (g_losses, 'Generator Loss', 'b'),
):
    plt.plot(loss_history, label=legend_text, color=line_color)
plt.gca().set(xlabel='Batch', ylabel='Loss', title='Loss Curves during Training')

plt.legend()
plt.tight_layout()
plt.savefig(f'combined_losses_epoch_{EPOCHS}_batch_{BATCH_SIZE}_image_{IMAGE_SIZE}_NNL.png')


 11%|█         | 3/27 [01:46<14:15, 35.64s/batch, D_Fake_Loss=1.44, D_Real_Loss=1.44, D_loss=2.87, G_loss=5.28]


KeyboardInterrupt: 