In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

# Dataset Class
class XRayDataset(Dataset):
    """Dataset of paired frontal / lateral X-ray images grouped by study uid.

    The CSV must provide the columns ``uid``, ``projection`` and ``filename``.
    Only uids that have at least one ``'Frontal'`` row and at least one
    ``'Lateral'`` row are kept; every item is the first frontal and the first
    lateral image of one such uid.

    Args:
        csv_file: Path of the CSV file describing the projections.
        img_folder: Directory that the ``filename`` column is relative to.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, csv_file, img_folder, transform=None):
        self.csv = pd.read_csv(csv_file)
        self.img_folder = img_folder
        self.transform = transform
        # One group of rows per study uid.
        self.pairs = self.csv.groupby('uid')
        self.valid_uids = self.check_integrity()

    def check_integrity(self):
        """Return the uids that own both a frontal and a lateral row."""
        def has_view(rows, view):
            return not rows[rows['projection'] == view].empty

        return [
            key
            for key, rows in self.pairs
            if has_view(rows, 'Frontal') and has_view(rows, 'Lateral')
        ]

    def __len__(self):
        """Number of uids that have a complete frontal/lateral pair."""
        return len(self.valid_uids)

    def __getitem__(self, idx):
        """Load the ``(frontal, lateral)`` image pair of the idx-th valid uid.

        Both images are converted to RGB; ``transform`` (if set) is applied
        to the frontal image first and to the lateral image second.
        """
        study = self.pairs.get_group(self.valid_uids[idx])

        images = []
        for view in ('Frontal', 'Lateral'):
            # First matching row wins when a uid has several images per view.
            filename = study[study['projection'] == view]['filename'].values[0]
            path = os.path.join(self.img_folder, filename)
            images.append(Image.open(path).convert("RGB"))

        if self.transform:
            images = [self.transform(image) for image in images]

        frontal_img, lateral_img = images
        return frontal_img, lateral_img

# Define the Dense Block
class DenseBlock(nn.Module):
    """Densely connected stack of 2D conv stages.

    Every stage is Conv2d(3x3, padding 1) -> InstanceNorm2d -> ReLU and emits
    ``growth_rate`` channels, which are concatenated onto the running feature
    map. The output therefore has
    ``in_channels + num_layers * growth_rate`` channels and the same spatial
    size as the input.

    Args:
        in_channels: Channel count of the input tensor.
        growth_rate: Channels produced by each stage.
        num_layers: Number of stages.
    """

    def __init__(self, in_channels, growth_rate, num_layers):
        super().__init__()
        stages = []
        for depth in range(num_layers):
            # Stage `depth` sees the input plus the output of all earlier stages.
            stage_in = in_channels + depth * growth_rate
            stages.append(
                nn.Sequential(
                    nn.Conv2d(stage_in, growth_rate, kernel_size=3, padding=1),
                    nn.InstanceNorm2d(growth_rate),
                    nn.ReLU(inplace=True),
                )
            )
        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        """Return ``x`` with every stage's output appended along dim 1."""
        features = x
        for stage in self.layers:
            features = torch.cat((features, stage(features)), dim=1)
        return features

# Define the Basic3D Block
class Basic3D(nn.Module):
    """Basic 3D unit: Conv3d(3x3x3, padding 1) -> InstanceNorm3d -> ReLU.

    Args:
        in_channels: Channel count of the input volume.
        out_channels: Channel count of the output volume.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        norm = nn.InstanceNorm3d(out_channels)
        activation = nn.ReLU(inplace=True)
        self.block = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Apply the conv / norm / activation stack to ``x``."""
        out = self.block(x)
        return out

# Define 2D to 3D Connection (Connection-C)
class Connection2Dto3D(nn.Module):
    """Bridge from a 2D feature map to a 3D feature volume.

    A 3x3 Conv2d is applied first, a singleton depth axis is inserted at
    dim 2, and a 3x3x3 Conv3d is applied to the resulting volume.

    Args:
        in_channels: Channel count of the 2D input.
        out_channels: Channel count of the 3D output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv3d = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map a 4D input to a 5D output with a depth axis of size 1."""
        planar = self.conv2d(x)
        # Insert the depth axis right after the channel axis.
        volume = torch.unsqueeze(planar, dim=2)
        return self.conv3d(volume)


# Updated Discriminator to be self-adjusting
class Discriminator(nn.Module):
    """3D convolutional discriminator producing one raw score per sample.

    The network is a stack of ``num_layers`` strided Conv3d stages
    (kernel 2, stride 2, padding 1). The first stage is followed by
    LeakyReLU only; every later stage doubles the channel count and adds
    InstanceNorm3d before the LeakyReLU. A final Conv3d reduces to a single
    channel and an adaptive average pool collapses whatever spatial extent
    remains, so the output length does not depend on the input volume size.
    No sigmoid is applied.

    Args:
        in_channels: Channel count of the input volume.
        base_channels: Channel count after the first stage.
        num_layers: Number of strided conv stages (default 4, the value that
            used to be hard-coded).

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, base_channels, num_layers=4):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be at least 1, got {num_layers}")

        self.layers = nn.ModuleList()
        # First stage: no normalisation directly on the input volume.
        self.layers.append(
            nn.Sequential(
                nn.Conv3d(in_channels, base_channels, kernel_size=2, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True)
            )
        )

        # Subsequent stages double the channel count each time.
        for i in range(1, num_layers):
            stage_in = base_channels * (2 ** (i - 1))
            stage_out = base_channels * (2 ** i)
            self.layers.append(
                nn.Sequential(
                    nn.Conv3d(stage_in, stage_out, kernel_size=2, stride=2, padding=1),
                    nn.InstanceNorm3d(stage_out),
                    nn.LeakyReLU(0.2, inplace=True)
                )
            )

        # Final layer: reduce to one channel, then pool to a single value.
        self.final_layer = nn.Conv3d(base_channels * (2 ** (num_layers - 1)), 1, kernel_size=2, stride=1, padding=0)
        self.pooling = nn.AdaptiveAvgPool3d(1)

    def forward(self, x):
        """Return a 1-D tensor holding one score per input sample."""
        for layer in self.layers:
            x = layer(x)

        x = self.final_layer(x)
        x = self.pooling(x)
        # After pooling every sample holds exactly one value, so flattening
        # gives a 1-D tensor with one entry per sample.
        return x.view(-1)

# Updated Generator to verify and propagate image size
class Generator(nn.Module):
    """Two-view generator: encodes two 2D images and decodes one 3D volume.

    Each view has its own 2D encoder (Conv2d -> DenseBlock -> stride-2
    Conv2d, which halves the spatial resolution). Both encodings pass through
    the *same* ``connection_a`` module to gain a depth axis, are concatenated
    along the channel axis, fused by ``connection_b`` and then upsampled by
    ``num_upconv_layers`` transposed 3D convolutions, each of which doubles
    depth, height and width. ``final_layer`` maps to ``out_channels``.

    ``image_size`` only determines the number of upsampling stages,
    ``floor(log2(image_size // 4))``; it is not checked against the actual
    input size in ``forward``.

    Args:
        in_channels: Channel count of each input image.
        growth_rate: Working channel width of the encoders and the decoder.
        num_dense_layers: Number of stages inside each DenseBlock.
        out_channels: Channel count of the generated volume.
        image_size: Nominal side length of the input images (must be >= 4).

    Raises:
        ValueError: If ``image_size`` is smaller than 4.
    """

    def __init__(self, in_channels, growth_rate, num_dense_layers, out_channels, image_size):
        super().__init__()

        if image_size < 4:
            # log2(image_size // 4) is undefined below 4; fail with a clear message.
            raise ValueError(f"image_size must be at least 4, got {image_size}")

        self.image_size = image_size
        # floor(log2(image_size // 4)) in exact integer arithmetic, which
        # avoids float32 rounding and a temporary tensor.
        self.num_upconv_layers = int(image_size // 4).bit_length() - 1

        def make_encoder():
            # Stem conv, dense feature extraction, then stride-2 compression.
            return nn.Sequential(
                nn.Conv2d(in_channels, growth_rate, kernel_size=3, padding=1),
                DenseBlock(growth_rate, growth_rate, num_dense_layers),
                nn.Conv2d(growth_rate * (num_dense_layers + 1), growth_rate, kernel_size=3, stride=2, padding=1)  # Compression
            )

        self.encoder1 = make_encoder()
        self.encoder2 = make_encoder()

        self.connection_a = Connection2Dto3D(growth_rate, growth_rate)
        self.connection_b = nn.Conv3d(growth_rate * 2, growth_rate, kernel_size=3, padding=1)
        # NOTE: connection_c is never called in forward(); it is kept so the
        # parameter set (and saved state-dict keys) stays unchanged.
        self.connection_c = nn.Conv3d(growth_rate, growth_rate, kernel_size=3, padding=1)

        self.upconv_layers = nn.ModuleList()
        for _ in range(self.num_upconv_layers):
            self.upconv_layers.append(
                nn.Sequential(
                    nn.ConvTranspose3d(growth_rate, growth_rate, kernel_size=4, stride=2, padding=1),
                    nn.InstanceNorm3d(growth_rate),
                    nn.ReLU(inplace=True)
                )
            )

        self.final_layer = nn.Conv3d(growth_rate, out_channels, kernel_size=3, padding=1)

    def forward(self, x1, x2):
        """Generate a volume from two image batches.

        Args:
            x1: Batch fed to ``encoder1``.
            x2: Batch fed to ``encoder2``.

        Returns:
            A 5D tensor with ``out_channels`` channels.
        """
        # Per-view 2D encoding.
        x1 = self.encoder1(x1)
        x2 = self.encoder2(x2)
        # Lift both encodings to 3D with the shared 2D->3D connection.
        x1 = self.connection_a(x1)
        x2 = self.connection_a(x2)
        # Fuse the two views along the channel axis.
        x = torch.cat([x1, x2], dim=1)
        x = self.connection_b(x)
        for upconv in self.upconv_layers:
            x = upconv(x)

        x = self.final_layer(x)
        return x
    
    
if __name__ == "__main__":
    # Script entry point: trains the two-view Generator against the 3D
    # Discriminator, saves both state dicts and plots the per-batch losses.

    # Paths and Hyperparameters
    IMAGE_FOLDER = '/projectnb/ec523kb/projects/teams_Fall_2024/Team_11/Adwait/Work_on_this_code/images/images_normalized/'
    CSV_FILE = '/projectnb/ec523kb/projects/teams_Fall_2024/Team_11/Adwait/Work_on_this_code/indiana_projections.csv'
    BATCH_SIZE = 16
    EPOCHS = 10
    IMAGE_SIZE = 100
    LATENT_DIM = 100  # not referenced below
    IN_CHANNELS = 3
    OUT_CHANNELS = 1
    GROWTH_RATE = 16
    NUM_DENSE_LAYERS = 4
    BASE_CHANNELS = 64

    # Per-batch loss histories, used for the plot at the end.
    d_real_losses = []
    d_fake_losses = []
    g_losses = []

    # Dataset and Dataloader
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor()
    ])

    dataset = XRayDataset(CSV_FILE, IMAGE_FOLDER, transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Initialize Generator and Discriminator
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = Generator(
        in_channels=IN_CHANNELS,
        growth_rate=GROWTH_RATE,
        num_dense_layers=NUM_DENSE_LAYERS,
        out_channels=OUT_CHANNELS,
        image_size=IMAGE_SIZE
    ).to(device)
    generator = nn.DataParallel(generator)

    discriminator = Discriminator(OUT_CHANNELS, BASE_CHANNELS).to(device)
    discriminator = nn.DataParallel(discriminator)

    g_optim = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optim = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # NOTE(review): the generator is trained with MSE on the raw discriminator
    # scores while the discriminator uses BCE-with-logits - confirm that this
    # mix is intentional.
    g_criterion = nn.MSELoss()
    d_criterion = nn.BCEWithLogitsLoss()

    # Training Loop
    for epoch in range(EPOCHS):
        generator.train()
        discriminator.train()
        g_epoch_loss = 0
        d_epoch_loss = 0

        with tqdm(dataloader, unit="batch") as tepoch:
            # The label is constant for the whole epoch, so set it once.
            tepoch.set_description(f"Epoch {epoch + 1}")
            for frontal_img, lateral_img in tepoch:
                frontal_img = frontal_img.to(device)
                lateral_img = lateral_img.to(device)
                real_labels = torch.ones((frontal_img.size(0),), device=device)
                fake_labels = torch.zeros((frontal_img.size(0),), device=device)

                # ===================== Train Discriminator =====================
                # Generate fake samples
                fake_volumes = generator(frontal_img, lateral_img)

                # "Real" samples
                # FIXME(review): the dataset yields no ground-truth volumes, so
                # this pass scores the *same* detached generator output as the
                # fake pass below, only with the opposite label. The
                # discriminator therefore cannot learn to separate real from
                # fake; replace the input with real volumes once available.
                real_outputs = discriminator(fake_volumes.detach())
                real_loss = d_criterion(real_outputs, real_labels)

                # Fake samples (detached so no gradient reaches the generator)
                fake_outputs = discriminator(fake_volumes.detach())
                fake_loss = d_criterion(fake_outputs, fake_labels)

                # Backpropagation for Discriminator
                d_loss = real_loss + fake_loss
                d_optim.zero_grad()
                d_loss.backward()
                d_optim.step()
                d_real_losses.append(real_loss.item())
                d_fake_losses.append(fake_loss.item())
                d_epoch_loss += d_loss.item()

                # ===================== Train Generator =====================
                # Re-score the non-detached volumes with the updated
                # discriminator and push the scores towards the "real" label.
                fake_outputs = discriminator(fake_volumes)
                g_loss = g_criterion(fake_outputs, real_labels)

                g_optim.zero_grad()
                g_loss.backward()
                g_optim.step()

                g_epoch_loss += g_loss.item()
                g_losses.append(g_loss.item())

                tepoch.set_postfix(d_loss=d_loss.item(), g_loss=g_loss.item(),D_Real_Loss = real_loss.item(), D_Fake_Loss = fake_loss.item())

        # Report the losses summed over all batches of this epoch (the
        # accumulators were previously computed but never shown).
        print(f"Epoch {epoch + 1}, Generator Loss: {g_epoch_loss:.4f}, Discriminator Loss: {d_epoch_loss:.4f}")

    print("Training finished.")

    # Save Models
    # Both models are wrapped in nn.DataParallel, so the saved keys carry the
    # "module." prefix.
    torch.save(generator.state_dict(), f"generator_{BATCH_SIZE}_{IMAGE_SIZE}_epoch_{EPOCHS}.pth")
    torch.save(discriminator.state_dict(), f"discriminator_{BATCH_SIZE}_{IMAGE_SIZE}_epoch_{EPOCHS}.pth")

    print(f"Models saved at epoch {EPOCHS} with batch size {BATCH_SIZE} and image size {IMAGE_SIZE}")

    # Plot the per-batch loss curves and write them to a PNG file.
    plt.figure(figsize=(10, 6))
    plt.plot(d_real_losses, label='Discriminator Real Loss', color='r')
    plt.plot(d_fake_losses, label='Discriminator Fake Loss', color='orange')
    plt.plot(g_losses, label='Generator Loss', color='b')
    plt.xlabel('Batch')
    plt.ylabel('Loss')
    plt.title('Loss Curves during Training')

    plt.legend()
    plt.tight_layout()
    plt.savefig(f'losses_epoch_{EPOCHS}_batch_{BATCH_SIZE}_image_{IMAGE_SIZE}_3DGANDiscriminator.png')



Epoch 1: 100%|██████████| 212/212 [12:38<00:00,  3.58s/batch, D_Fake_Loss=0.674, D_Real_Loss=0.714, d_loss=1.39, g_loss=0.726]
  0%|          | 0/212 [00:00<?, ?batch/s]

Epoch 1, Generator Loss: 222.0677, Discriminator Loss: 294.9852


Epoch 2: 100%|██████████| 212/212 [12:03<00:00,  3.41s/batch, D_Fake_Loss=0.701, D_Real_Loss=0.684, d_loss=1.38, g_loss=1.17] 
  0%|          | 0/212 [00:00<?, ?batch/s]

Epoch 2, Generator Loss: 213.1370, Discriminator Loss: 294.0971


Epoch 3:  48%|████▊     | 101/212 [05:49<06:24,  3.46s/batch, D_Fake_Loss=0.673, D_Real_Loss=0.713, d_loss=1.39, g_loss=0.887]

In [None]:
# Roughly 12 minutes per epoch, so budget about 15 minutes per epoch when submitting a batch job :)
# Also, this configuration needs nearly 70 GB of VRAM, e.g. one A100 or two L40S GPUs.