In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchsummary import summary
from tqdm import tqdm

from dataloader import dataset

In [None]:
# Pick the compute device once for the whole notebook: CUDA when PyTorch
# reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bare trailing expression so the notebook displays which device was chosen.
device


In [None]:
class Conv_BN_ReLU(nn.Module):
    """3D convolution -> batch normalisation -> ReLU.

    The convolution uses a 3x3x3 kernel with stride 1 and padding 1, so the
    spatial size of the input is preserved and only the channel count changes
    from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        # Attribute names double as state_dict key prefixes; keep them stable.
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.norm = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply convolution, normalisation and activation in that order."""
        return self.relu(self.norm(self.conv(x)))


In [None]:
class DeConv_BN_ReLU(nn.Module):
    """3D transposed convolution -> batch normalisation -> ReLU.

    With kernel size 4, stride 2 and padding 1 the transposed convolution
    doubles every spatial dimension of its input while mapping
    ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        # Attribute names double as state_dict key prefixes; keep them stable.
        self.deconv = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
        )
        self.bn = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Upsample ``x`` by 2, then normalise and activate."""
        upsampled = self.deconv(x)
        return self.relu(self.bn(upsampled))

In [None]:
class UNet3D(nn.Module):
    """3D U-Net style encoder/decoder used as the cGAN generator.

    Encoder: four stages of 3x3x3 conv + batch norm + ReLU, each followed by
    a 2x max-pool (stages 3 and 4 contain two convolutions).  Decoder: four
    stages that upsample by 2 with ``DeConv_BN_ReLU``, concatenate a skip
    tensor along the channel axis and fuse with ``Conv_BN_ReLU``.  A final
    1x1x1 convolution maps 32 channels to ``out_channels``.

    The skip tensors are the raw convolution outputs of each encoder stage
    (taken before batch norm and ReLU).

    Because the input is halved four times and doubled four times, every
    spatial dimension must be divisible by 16 for the skip concatenations to
    line up.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Shared settings for all "same"-size encoder convolutions.
        same_conv = dict(kernel_size=3, padding=1, stride=1)

        # ---- encoder stage 1: in_channels -> 64
        self.conv1 = nn.Conv3d(in_channels, 64, **same_conv)
        self.norm1 = nn.BatchNorm3d(64)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)

        # ---- encoder stage 2: 64 -> 128
        self.conv2 = nn.Conv3d(64, 128, **same_conv)
        self.norm2 = nn.BatchNorm3d(128)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)

        # ---- encoder stage 3: 128 -> 256 -> 256
        self.conv3_1 = nn.Conv3d(128, 256, **same_conv)
        self.norm3_1 = nn.BatchNorm3d(256)
        self.relu3_1 = nn.ReLU()
        self.conv3_2 = nn.Conv3d(256, 256, **same_conv)
        self.norm3_2 = nn.BatchNorm3d(256)
        self.relu3_2 = nn.ReLU()
        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)

        # ---- encoder stage 4: 256 -> 512 -> 512
        self.conv4_1 = nn.Conv3d(256, 512, **same_conv)
        self.norm4_1 = nn.BatchNorm3d(512)
        self.relu4_1 = nn.ReLU()
        # NOTE(review): `encoder4_2` is registered but never called in
        # forward(); it is kept so the parameter / state_dict layout stays
        # the same as before.
        self.encoder4_2 = Conv_BN_ReLU(512, 512)
        self.conv4_2 = nn.Conv3d(512, 512, **same_conv)
        self.norm4_2 = nn.BatchNorm3d(512)
        self.relu4_2 = nn.ReLU()
        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)

        # ---- bottleneck at the coarsest resolution
        self.bottleneck1 = Conv_BN_ReLU(512, 512)
        self.bottleneck2 = Conv_BN_ReLU(512, 512)

        # ---- decoder: each stage is (upsample x2) then (fuse with skip)
        self.decoder1_1 = DeConv_BN_ReLU(512, 512)
        self.decoder1_2 = Conv_BN_ReLU(1024, 256)

        self.decoder2_1 = DeConv_BN_ReLU(256, 256)
        self.decoder2_2 = Conv_BN_ReLU(512, 128)

        self.decoder3_1 = DeConv_BN_ReLU(128, 128)
        self.decoder3_2 = Conv_BN_ReLU(256, 64)

        self.decoder4_1 = DeConv_BN_ReLU(64, 64)
        self.decoder4_2 = Conv_BN_ReLU(128, 32)

        # ---- 1x1x1 projection to the requested number of output channels
        self.final_conv = nn.Conv3d(32, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        """Run the U-Net; shape comments assume a 64x64x64 input volume."""
        # Encoder stage 1 (skip tensor is the pre-norm conv output).
        skip1 = self.conv1(x)                                   # [bs, 64, 64, 64, 64]
        feat = self.pool1(self.relu1(self.norm1(skip1)))        # [bs, 64, 32, 32, 32]

        # Encoder stage 2.
        skip2 = self.conv2(feat)                                # [bs, 128, 32, 32, 32]
        feat = self.pool2(self.relu2(self.norm2(skip2)))        # [bs, 128, 16, 16, 16]

        # Encoder stage 3 (two convolutions, skip taken from the second).
        feat = self.relu3_1(self.norm3_1(self.conv3_1(feat)))   # [bs, 256, 16, 16, 16]
        skip3 = self.conv3_2(feat)                              # [bs, 256, 16, 16, 16]
        feat = self.pool3(self.relu3_2(self.norm3_2(skip3)))    # [bs, 256, 8, 8, 8]

        # Encoder stage 4 (two convolutions, skip taken from the second).
        feat = self.relu4_1(self.norm4_1(self.conv4_1(feat)))   # [bs, 512, 8, 8, 8]
        skip4 = self.conv4_2(feat)                              # [bs, 512, 8, 8, 8]
        feat = self.pool4(self.relu4_2(self.norm4_2(skip4)))    # [bs, 512, 4, 4, 4]

        # Bottleneck.
        feat = self.bottleneck2(self.bottleneck1(feat))         # [bs, 512, 4, 4, 4]

        # Decoder stage 1.
        feat = self.decoder1_1(feat)                            # [bs, 512, 8, 8, 8]
        feat = torch.cat((feat, skip4), dim=1)                  # [bs, 1024, 8, 8, 8]
        feat = self.decoder1_2(feat)                            # [bs, 256, 8, 8, 8]

        # Decoder stage 2.
        feat = self.decoder2_1(feat)                            # [bs, 256, 16, 16, 16]
        feat = torch.cat((feat, skip3), dim=1)                  # [bs, 512, 16, 16, 16]
        feat = self.decoder2_2(feat)                            # [bs, 128, 16, 16, 16]

        # Decoder stage 3.
        feat = self.decoder3_1(feat)                            # [bs, 128, 32, 32, 32]
        feat = torch.cat((feat, skip2), dim=1)                  # [bs, 256, 32, 32, 32]
        feat = self.decoder3_2(feat)                            # [bs, 64, 32, 32, 32]

        # Decoder stage 4.
        feat = self.decoder4_1(feat)                            # [bs, 64, 64, 64, 64]
        feat = torch.cat((feat, skip1), dim=1)                  # [bs, 128, 64, 64, 64]
        feat = self.decoder4_2(feat)                            # [bs, 32, 64, 64, 64]

        return self.final_conv(feat)                            # [bs, out_channels, 64, 64, 64]

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that scores a whole 3D volume.

    Four stride-2 convolutions (kernel 4, padding 1) with LeakyReLU(0.2)
    reduce the volume, the result is flattened and a single linear layer
    followed by a sigmoid yields one probability per sample.

    Args:
        in_channels: Number of channels of the input volume.
        input_size: Spatial size of the volumes the discriminator will see,
            either one int (cubic volume) or a ``(D, H, W)`` sequence.
            Each dimension must be at least 16.  Defaults to 64, which gives
            the 32768 flattened features the layer was originally
            hard-coded for.

    Raises:
        ValueError: If ``input_size`` is not an int or a 3-element sequence,
            or if any dimension is smaller than 16.
    """

    def __init__(self, in_channels, input_size=64):
        super(Discriminator, self).__init__()

        if isinstance(input_size, int):
            input_size = (input_size, input_size, input_size)
        input_size = tuple(int(s) for s in input_size)
        if len(input_size) != 3:
            raise ValueError(
                f"input_size must be an int or a (D, H, W) sequence, got {input_size!r}"
            )
        # A kernel-4 / stride-2 / padding-1 convolution maps a length n to
        # n // 2, so four of them map n to n // 16.
        reduced = [s // 16 for s in input_size]
        if min(reduced) < 1:
            raise ValueError(
                f"every spatial dimension must be >= 16, got {input_size!r}"
            )
        flat_features = 512 * reduced[0] * reduced[1] * reduced[2]

        self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv3d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv3d(128, 256, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv3d(256, 512, kernel_size=4, stride=2, padding=1)
        self.linear = nn.Linear(flat_features, 1)

        self.leaky_relu = nn.LeakyReLU(0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a ``[batch, 1]`` tensor of probabilities in ``[0, 1]``."""
        x = self.leaky_relu(self.conv1(x))
        x = self.leaky_relu(self.conv2(x))
        x = self.leaky_relu(self.conv3(x))
        x = self.leaky_relu(self.conv4(x))

        # Collapse channels and spatial dims into one feature vector per sample.
        x = x.view(x.size(0), -1)

        x = self.linear(x)
        x = self.sigmoid(x)

        return x

In [None]:
# Scratch check: concatenating two [8, 2, 64, 64, 64] tensors along dim=1
# (the channel axis) yields an [8, 4, 64, 64, 64] tensor.
tensor1 = torch.randn((8, 2, 64, 64, 64)).to(device)
tensor2 = torch.randn((8, 2, 64, 64, 64)).to(device)
print(torch.concat((tensor1,tensor2),dim=1).shape)

In [None]:
# Smoke test: build a 2-channel discriminator, push one random batch of two
# 64x64x64 volumes through it and report the output shape.
dis = Discriminator(in_channels=2).to(device)
input_tensor = torch.randn((2, 2, 64, 64, 64)).to(device)
out = dis(input_tensor)
# A bare `out.shape` in the middle of a cell is never displayed (only the
# last expression of a cell is), so print it explicitly.
print(out.shape)
# Per-layer summary of the discriminator for a single (2, 64, 64, 64) sample.
summary(dis, input_size=(2, 64, 64, 64))

In [None]:
class cGAN_IM2IM(nn.Module):
    """Conditional GAN for volume-to-volume translation.

    Bundles a ``UNet3D`` generator with a ``Discriminator`` that judges the
    channel-wise concatenation of a conditioning input and an output volume
    (real target or generated).

    Args:
        in_channels: Channels of the conditioning input volume.
        out_channels: Channels of the generated / target volume.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super(cGAN_IM2IM, self).__init__()
        self.in_channel = in_channels
        self.out_channel = out_channels
        self.generator = self.build_generator()
        self.discriminator = self.build_discriminator()

    def build_generator(self):
        """Create the generator mapping ``in_channel`` -> ``out_channel`` channels."""
        return UNet3D(self.in_channel, self.out_channel)

    def build_discriminator(self):
        """Create the discriminator for concatenated (input, output) pairs.

        The pair has ``in_channel + out_channel`` channels.  This equals the
        previous ``in_channel * 2`` whenever both counts match, and is the
        correct width when they differ.
        """
        return Discriminator(self.in_channel + self.out_channel)

    def forward(self, x):
        """Generate a volume from ``x`` and score the (input, generated) pair.

        Returns:
            Tuple ``(generated_images, discriminator_output)``.
        """
        generated_images = self.generator(x)
        # The discriminator is conditional: it must see the input next to the
        # generated volume, exactly as the training loop feeds it.  Passing
        # only the generated volume gives it too few channels.
        pair = torch.cat((x, generated_images), dim=1)
        discriminator_output = self.discriminator(pair)
        return generated_images, discriminator_output

In [None]:
# Training split: built from the two data directories with start_index=144 and
# end_index=720.  `dataset` comes from the project-local `dataloader` module,
# so the exact meaning of `force` and the index bounds is not visible here
# (presumably a sample range - TODO confirm against dataloader.py).
Trainingset = dataset(file_path1="./reg_data/00/",file_path2="./reg_data/12/",force=0,start_index=144,end_index=720)
trainingloader = DataLoader(dataset=Trainingset,batch_size=8,shuffle=True)

# Held-out split: same directories with start_index=0 and end_index=144; it is
# passed to the training function as the validation loader.
# NOTE(review): shuffle=True is not needed for evaluation; harmless, since the
# validation loss is averaged over the loader.
Testingset = dataset(file_path1="./reg_data/00/",file_path2="./reg_data/12/",force=0,start_index=0,end_index=144)
testloader = DataLoader(dataset=Testingset,batch_size=8,shuffle=True)

In [None]:
def eval_cgan_im2im(model, val_loader, criterion, device=None):
    """Compute the generator's mean per-sample loss over a validation loader.

    Args:
        model: Object exposing a ``generator`` module (and ``eval()`` /
            ``parameters()``), e.g. ``cGAN_IM2IM``.
        val_loader: Iterable yielding ``(inputs, targets, _)`` batches; a
            channel axis is inserted at dim 1 of inputs and targets.
        criterion: Loss called as ``criterion(generated, targets)``; assumed
            to return the batch mean, since it is weighted by batch size.
        device: Device the batches are moved to.  When omitted, the device of
            the model's first parameter is used (CPU if it has none), so the
            function no longer depends on a notebook-global ``device``.

    Returns:
        The loss averaged over all samples actually seen, or ``nan`` when the
        loader yields no samples.

    Note:
        The model is left in eval mode; callers that keep training must call
        ``model.train()`` again.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    loss_sum = 0.0
    samples_seen = 0
    with torch.no_grad():
        for inputs, targets, _ in val_loader:
            inputs = inputs.to(device).unsqueeze(1).float()
            targets = targets.to(device).unsqueeze(1).float()
            batch_size = inputs.size(0)

            generated_images = model.generator(inputs)
            loss = criterion(generated_images, targets)

            loss_sum += loss.item() * batch_size
            samples_seen += batch_size

    # Normalise by the samples actually processed rather than
    # len(val_loader.dataset): the two differ with drop_last or a subset
    # sampler, and the latter raises ZeroDivisionError on an empty dataset.
    if samples_seen == 0:
        return float("nan")
    return loss_sum / samples_seen

In [None]:
def train_cgan_im2im(model, train_loader, val_loader, num_epochs, device, lr=0.0002, beta1=0.5, beta2=0.999):
    """Train the conditional GAN with an adversarial + similarity objective.

    Per batch, the discriminator is updated on (input, target) pairs labelled
    real and (input, generated) pairs labelled fake.  The generator is then
    updated to fool the discriminator (BCE against "real" labels) while
    staying close to the target (MSE, weighted by 100).

    Args:
        model: cGAN exposing ``generator`` and ``discriminator`` modules.
        train_loader: Loader yielding ``(inputs, targets, _)`` batches.
        val_loader: Loader passed to ``eval_cgan_im2im`` after each epoch.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the model and batches are moved to.
        lr: Learning rate of both Adam optimisers.
        beta1: First Adam beta of both optimisers.
        beta2: Second Adam beta of both optimisers.

    Returns:
        Tuple ``(gen_losses, sim_losses, dis_losses, val_losses)`` holding one
        per-sample average per epoch: adversarial generator loss, similarity
        (MSE) loss, discriminator loss and validation loss.

    Side effects:
        Saves ``model.state_dict()`` under ``./saved_model/cGAN/`` whenever
        the validation loss reaches a new minimum that is also below 0.05.
    """
    optimizer_gen = optim.Adam(model.generator.parameters(), lr=lr, betas=(beta1, beta2))
    optimizer_disc = optim.Adam(model.discriminator.parameters(), lr=lr, betas=(beta1, beta2))
    criterion_BCE = nn.BCELoss()
    criterion = nn.MSELoss()
    # Weight of the similarity (MSE) term relative to the adversarial term.
    sim_weight = 100

    model.to(device)
    gen_losses = []
    sim_losses = []
    dis_losses = []
    val_losses = []
    # Best validation loss so far; starts at 1.0, so only losses below 1.0
    # can ever count as an improvement.
    min_val = 1.0

    for epoch in range(num_epochs):
        gen_loss_total = 0.0
        sim_loss_total = 0.0
        disc_loss_total = 0.0
        model.train()
        for inputs, targets, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            # Insert the channel axis and cast to float32.
            inputs = inputs.to(device).unsqueeze(1).float()
            targets = targets.to(device).unsqueeze(1).float()
            batch_size = inputs.size(0)

            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # One generator forward pass serves both updates: the generator's
            # weights do not change until optimizer_gen.step() below.
            generated_images = model.generator(inputs)

            # ---- discriminator update ----
            optimizer_disc.zero_grad()
            real_outputs = model.discriminator(torch.concat((inputs, targets), dim=1))
            # detach(): the discriminator loss must not backpropagate into
            # the generator.
            fake_outputs = model.discriminator(torch.concat((inputs, generated_images.detach()), dim=1))

            disc_loss_real = criterion_BCE(real_outputs, real_labels)
            disc_loss_fake = criterion_BCE(fake_outputs, fake_labels)
            disc_loss = disc_loss_real + disc_loss_fake

            disc_loss.backward()
            optimizer_disc.step()

            # ---- generator update ----
            optimizer_gen.zero_grad()
            # No detach() here: the adversarial loss has to flow back through
            # the discriminator into the generator.  Detaching at this point
            # silently reduced training to plain MSE regression.  Gradients
            # that land on the discriminator are discarded by
            # optimizer_disc.zero_grad() at the next iteration.
            fake_outputs = model.discriminator(torch.concat((inputs, generated_images), dim=1))

            gen_loss1 = criterion_BCE(fake_outputs, real_labels)
            gen_loss2 = criterion(generated_images, targets)
            gen_loss = gen_loss1 + sim_weight * gen_loss2
            gen_loss.backward()
            optimizer_gen.step()

            gen_loss_total += gen_loss1.item() * batch_size
            sim_loss_total += gen_loss2.item() * batch_size
            disc_loss_total += disc_loss.item() * batch_size

        gen_loss_total /= len(train_loader.dataset)
        gen_losses.append(gen_loss_total)
        sim_loss_total /= len(train_loader.dataset)
        sim_losses.append(sim_loss_total)
        disc_loss_total /= len(train_loader.dataset)
        dis_losses.append(disc_loss_total)

        val_loss = eval_cgan_im2im(model, val_loader, criterion)
        val_losses.append(val_loss)

        if val_loss < min_val:
            min_val = val_loss
            if val_loss < 0.05:
                # torch.save does not create missing directories; make sure a
                # long run does not die on its first checkpoint.
                os.makedirs('./saved_model/cGAN', exist_ok=True)
                # NOTE(review): the "lr1e-6" suffix is hard-coded and does not
                # follow the `lr` argument.
                torch.save(model.state_dict(), f'./saved_model/cGAN/{epoch}_{val_loss:.4f}_lr1e-6.pth')

        print(f"Epoch [{epoch+1}/{num_epochs}], Generator Loss: {gen_loss_total:.4f}, Similarity Loss: {sim_loss_total:.4f}, Discriminator Loss: {disc_loss_total:.4f},val Loss: {val_loss:.4f}")

    print('Finished Training')

    return gen_losses, sim_losses, dis_losses, val_losses


In [None]:
# Build the cGAN for single-channel input and single-channel output volumes
# and move all of its parameters to the selected device.
gan_model = cGAN_IM2IM(1, 1).to(device)

In [None]:
# Train for 400 epochs with lr=1e-6, validating on `testloader` after every
# epoch; each returned name is a list with one averaged loss per epoch.
gen_loss, sim_loss, dis_loss, val_loss = train_cgan_im2im(gan_model, trainingloader, testloader, 400, device, lr=0.000001, beta1=0.5, beta2=0.999)