In [None]:
import itertools
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchsummary import summary
from tqdm import tqdm

from dataloader import dataset

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
class Conv_BN_ReLU(nn.Module):
    """3-D convolution followed by batch normalisation and a ReLU.

    The convolution uses a 3x3x3 kernel with stride 1 and padding 1, so the
    spatial size (D, H, W) of the input is preserved; only the number of
    channels changes, from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: channel count of the incoming volume.
        out_channels: channel count produced by the convolution.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        # The attribute names (conv / norm) become state_dict keys of every
        # network built from this block, so they must stay stable for
        # previously saved checkpoints to load.
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ReLU(BatchNorm(Conv3d(x)))."""
        return self.relu(self.norm(self.conv(x)))


In [None]:
class DeConv_BN_ReLU(nn.Module):
    """Transposed 3-D convolution (2x upsampling), batch norm, then ReLU.

    With kernel_size=4, stride=2 and padding=1 each spatial dimension of size
    S becomes (S - 1) * 2 - 2 * 1 + 4 = 2 * S, i.e. the volume is doubled.

    Args:
        in_channels: channel count of the incoming volume.
        out_channels: channel count produced by the transposed convolution.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        # Attribute names (deconv / bn) are state_dict keys of saved models.
        self.deconv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.bn = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Upsample ``x`` by 2 in every spatial dimension, normalise, rectify."""
        upsampled = self.deconv(x)
        normalised = self.bn(upsampled)
        return self.relu(normalised)

In [None]:
class UNet3D(nn.Module):
    """3-D U-Net generator mapping a volume to a volume of the same spatial size.

    Layout (channel counts):
      * encoder: 64 -> pool -> 128 -> pool -> 256, 256 -> pool -> 512, 512 -> pool
      * bottleneck: two Conv_BN_ReLU(512, 512)
      * decoder: four DeConv_BN_ReLU 2x upsamplings, each concatenated with the
        encoder feature map of the same resolution and fused by a Conv_BN_ReLU
        (512+512 -> 256, 256+256 -> 128, 128+128 -> 64, 64+64 -> 32)
      * head: 1x1x1 Conv3d(32, out_channels), no output activation

    The four 2x max-pools mean every spatial dimension of the input must be a
    multiple of 16. Shape comments in ``forward`` are for a
    (bs, in_channels, 64, 64, 64) input.

    Args:
        in_channels: channels of the input volume.
        out_channels: channels of the output volume.
    """

    def __init__(self, in_channels, out_channels):
        super(UNet3D, self).__init__()
        # Modules are created in the original order with the original attribute
        # names, so state_dict keys (i.e. saved checkpoints) and the random
        # initialisation sequence are unchanged.

        # Encoder level 1
        self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=3, padding=1, stride=1)
        self.norm1 = nn.BatchNorm3d(64)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
        # Encoder level 2
        self.conv2 = nn.Conv3d(64, 128, kernel_size=3, padding=1, stride=1)
        self.norm2 = nn.BatchNorm3d(128)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
        # Encoder level 3 (two conv/BN/ReLU stages)
        self.conv3_1 = nn.Conv3d(128, 256, kernel_size=3, padding=1, stride=1)
        self.norm3_1 = nn.BatchNorm3d(256)
        self.relu3_1 = nn.ReLU()
        self.conv3_2 = nn.Conv3d(256, 256, kernel_size=3, padding=1, stride=1)
        self.norm3_2 = nn.BatchNorm3d(256)
        self.relu3_2 = nn.ReLU()
        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
        # Encoder level 4 (two conv/BN/ReLU stages)
        self.conv4_1 = nn.Conv3d(256, 512, kernel_size=3, padding=1, stride=1)
        self.norm4_1 = nn.BatchNorm3d(512)
        self.relu4_1 = nn.ReLU()
        # NOTE(review): `encoder4_2` is never used in forward() -- conv4_2 /
        # norm4_2 / relu4_2 below do that job. It holds ~7M parameters that
        # never receive gradients. It is kept only because removing it would
        # change the state_dict keys and break strict loading of checkpoints.
        self.encoder4_2 = Conv_BN_ReLU(512, 512)
        self.conv4_2 = nn.Conv3d(512, 512, kernel_size=3, padding=1, stride=1)
        self.norm4_2 = nn.BatchNorm3d(512)
        self.relu4_2 = nn.ReLU()
        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)
        # Bottleneck
        self.bottleneck1 = Conv_BN_ReLU(512, 512)
        self.bottleneck2 = Conv_BN_ReLU(512, 512)
        # Decoder: (upsample, fuse-with-skip) pairs
        self.decoder1_1 = DeConv_BN_ReLU(512, 512)
        self.decoder1_2 = Conv_BN_ReLU(1024, 256)
        self.decoder2_1 = DeConv_BN_ReLU(256, 256)
        self.decoder2_2 = Conv_BN_ReLU(512, 128)
        self.decoder3_1 = DeConv_BN_ReLU(128, 128)
        self.decoder3_2 = Conv_BN_ReLU(256, 64)
        self.decoder4_1 = DeConv_BN_ReLU(64, 64)
        self.decoder4_2 = Conv_BN_ReLU(128, 32)
        # 1x1x1 projection to the output channels
        self.final_conv = nn.Conv3d(32, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        """Translate ``x`` of shape (bs, in_channels, D, H, W).

        Returns a tensor of shape (bs, out_channels, D, H, W).

        Raises:
            ValueError: if ``x`` is not 5-D or D, H, W are not multiples of 16.
        """
        if x.dim() != 5:
            raise ValueError(
                f"UNet3D expects a 5-D (bs, C, D, H, W) tensor, got shape {tuple(x.shape)}"
            )
        if any(size % 16 for size in x.shape[2:]):
            # Four 2x poolings followed by four 2x upsamplings only line up
            # for the skip concatenations when each dim is divisible by 16.
            raise ValueError(
                "UNet3D needs spatial dims divisible by 16 (four 2x poolings), "
                f"got {tuple(x.shape[2:])}"
            )

        # NOTE(review): the skip connections concatenate the raw convolution
        # outputs (conv4_2, conv3_2, conv2, conv1), i.e. *before* BatchNorm and
        # ReLU, whereas a standard U-Net skips the activated features. Changing
        # it would alter the behaviour of any weights trained with this class,
        # so it is left as-is -- confirm it is intended.

        # Encoder level 1: conv1 [bs, 64, 64, 64, 64] -> pool1 [bs, 64, 32, 32, 32]
        conv1 = self.conv1(x)
        pool1 = self.pool1(self.relu1(self.norm1(conv1)))
        # Encoder level 2: conv2 [bs, 128, 32, 32, 32] -> pool2 [bs, 128, 16, 16, 16]
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(self.relu2(self.norm2(conv2)))
        # Encoder level 3: conv3_2 [bs, 256, 16, 16, 16] -> pool3 [bs, 256, 8, 8, 8]
        relu3_1 = self.relu3_1(self.norm3_1(self.conv3_1(pool2)))
        conv3_2 = self.conv3_2(relu3_1)
        pool3 = self.pool3(self.relu3_2(self.norm3_2(conv3_2)))
        # Encoder level 4: conv4_2 [bs, 512, 8, 8, 8] -> pool4 [bs, 512, 4, 4, 4]
        relu4_1 = self.relu4_1(self.norm4_1(self.conv4_1(pool3)))
        conv4_2 = self.conv4_2(relu4_1)
        pool4 = self.pool4(self.relu4_2(self.norm4_2(conv4_2)))

        # Bottleneck: [bs, 512, 4, 4, 4]
        bottleneck = self.bottleneck2(self.bottleneck1(pool4))

        # Decoder level 1: up to [bs, 512, 8, 8, 8], cat -> [bs, 1024, 8, 8, 8], fuse -> [bs, 256, 8, 8, 8]
        dec1_1 = self.decoder1_1(bottleneck)
        dec1_2 = self.decoder1_2(torch.cat((dec1_1, conv4_2), dim=1))
        # Decoder level 2: up to [bs, 256, 16, 16, 16], cat -> 512 ch, fuse -> [bs, 128, 16, 16, 16]
        dec2_1 = self.decoder2_1(dec1_2)
        dec2_2 = self.decoder2_2(torch.cat((dec2_1, conv3_2), dim=1))
        # Decoder level 3: up to [bs, 128, 32, 32, 32], cat -> 256 ch, fuse -> [bs, 64, 32, 32, 32]
        dec3_1 = self.decoder3_1(dec2_2)
        dec3_2 = self.decoder3_2(torch.cat((dec3_1, conv2), dim=1))
        # Decoder level 4: up to [bs, 64, 64, 64, 64], cat -> 128 ch, fuse -> [bs, 32, 64, 64, 64]
        dec4_1 = self.decoder4_1(dec3_2)
        dec4_2 = self.decoder4_2(torch.cat((dec4_1, conv1), dim=1))

        # Head: [bs, out_channels, 64, 64, 64]
        return self.final_conv(dec4_2)

In [None]:
class Discriminator(nn.Module):
    """Full-resolution 3-D discriminator producing one probability per sample.

    Four Conv_BN_ReLU blocks (channels: in_channels -> 32 -> 64 -> 64 -> 64)
    use 3x3x3 kernels with stride 1 and padding 1, so the spatial size is
    unchanged and the flattened feature vector has 64 * D * H * W entries.
    The final linear layer is therefore sized for one fixed input volume.

    Args:
        in_channels: channels of the input volume.
        input_size: spatial size of the input -- an int for a cube
            (D = H = W) or a (D, H, W) tuple. The default of 64 gives the
            original 64 * 64**3 = 16777216 linear input features, so existing
            checkpoints keep loading.
    """

    def __init__(self, in_channels, input_size=64):
        super(Discriminator, self).__init__()
        if isinstance(input_size, int):
            input_size = (input_size, input_size, input_size)
        depth, height, width = input_size
        self.conv1 = Conv_BN_ReLU(in_channels, 32)
        self.conv2 = Conv_BN_ReLU(32, 64)
        self.conv3 = Conv_BN_ReLU(64, 64)
        self.conv4 = Conv_BN_ReLU(64, 64)
        # 64 channels out of conv4, spatial size unchanged by the convolutions.
        self.linear = nn.Linear(64 * depth * height * width, 1)

    def forward(self, x):
        """Return a (bs, 1) tensor of "real" probabilities in (0, 1).

        ``x`` is assumed to be (bs, in_channels, D, H, W) with the spatial size
        passed to the constructor; any other size makes the linear layer fail.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        x = x.view(x.size(0), -1)
        x = self.linear(x)
        # NOTE(review): sigmoid + nn.BCELoss is less numerically stable than
        # returning logits and using nn.BCEWithLogitsLoss; kept because the
        # training loops feed these probabilities to nn.BCELoss.
        return torch.sigmoid(x)

In [None]:
# Paired training volumes for stage 1. The loader yields (inputs, targets, force)
# batches of 4 in shuffled order; presumably file_path1 supplies `inputs`
# (domain A) and file_path2 `targets` (domain B) -- confirm in dataloader.py.
Trainingset = dataset(
    file_path1="./reg_data/00/",
    file_path2="./reg_data/12/",
    force=0,
)
trainingloader = DataLoader(
    dataset=Trainingset,
    batch_size=4,
    shuffle=True,
)

In [None]:
# Two single-channel generators (A -> B and B -> A) and one discriminator per
# domain, all moved to the selected device. The generator expressions build
# the networks in the same order as before (G_0_12, G_12_0, D_0, D_12), so
# random initialisation consumes the RNG identically.
netG_0_12, netG_12_0 = (UNet3D(1, 1).to(device) for _ in range(2))
netD_0, netD_12 = (Discriminator(1).to(device) for _ in range(2))

In [None]:
# Adversarial loss on the discriminators' sigmoid outputs; L1 for cycle consistency.
criterion_GAN = nn.BCELoss()
criterion_cycle = nn.L1Loss()

# Both generators share one Adam optimizer; each discriminator has its own,
# with a learning rate 10x lower than the generators'.
_adam_betas = (0.5, 0.999)
optimizer_G = optim.Adam(
    itertools.chain(netG_0_12.parameters(), netG_12_0.parameters()),
    lr=1e-6,
    betas=_adam_betas,
)
optimizer_D_A = optim.Adam(netD_0.parameters(), lr=1e-7, betas=_adam_betas)
optimizer_D_B = optim.Adam(netD_12.parameters(), lr=1e-7, betas=_adam_betas)

In [None]:
num_epochs: int = 400  # full passes over `trainingloader` in stage-1 training

In [None]:
# Stage-1 CycleGAN training. For each mini-batch: one joint generator update
# (adversarial + cycle-consistency losses), then one update of each
# discriminator. Domain A = `inputs`, domain B = `targets` from the loader.
for epoch in range(num_epochs):
    gen_loss_total = 0.0
    disc0_loss_total = 0.0
    disc12_loss_total = 0.0
    for inputs, targets, _force in tqdm(trainingloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # Add a channel axis and cast to float32 for the 3-D convolutions.
        real_A = inputs.to(device).unsqueeze(1).float()
        real_B = targets.to(device).unsqueeze(1).float()

        # ---- Generator update ------------------------------------------------
        # The discriminators act as fixed critics here. Freezing them skips
        # computing gradients for their weights, which the discriminator
        # zero_grad() calls below would discard anyway.
        netD_0.requires_grad_(False)
        netD_12.requires_grad_(False)
        optimizer_G.zero_grad()

        fake_B = netG_0_12(real_A)   # A -> B
        recov_A = netG_12_0(fake_B)  # A -> B -> A
        fake_A = netG_12_0(real_B)   # B -> A
        recov_B = netG_0_12(fake_A)  # B -> A -> B

        # Adversarial losses: each generator wants the discriminator of its
        # output domain to label the translation as real (target 1).
        pred_fake_A = netD_0(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))
        pred_fake_B = netD_12(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B))

        # Cycle-consistency losses: both round trips should reproduce the input.
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)

        loss_G = loss_GAN_B2A + loss_GAN_A2B + loss_cycle_A + loss_cycle_B
        loss_G.backward()
        optimizer_G.step()

        # ---- Discriminator updates -------------------------------------------
        netD_0.requires_grad_(True)
        netD_12.requires_grad_(True)

        # D_0 (domain A): real_A -> 1, fake_A -> 0. detach() stops these losses
        # from back-propagating into the generators.
        optimizer_D_A.zero_grad()
        pred_real_A = netD_0(real_A)
        loss_D_real_A = criterion_GAN(pred_real_A, torch.ones_like(pred_real_A))
        pred_fake_A = netD_0(fake_A.detach())
        loss_D_fake_A = criterion_GAN(pred_fake_A, torch.zeros_like(pred_fake_A))
        loss_D_A = 0.5 * (loss_D_real_A + loss_D_fake_A)
        loss_D_A.backward()
        optimizer_D_A.step()

        # D_12 (domain B): real_B -> 1, fake_B -> 0.
        optimizer_D_B.zero_grad()
        pred_real_B = netD_12(real_B)
        loss_D_real_B = criterion_GAN(pred_real_B, torch.ones_like(pred_real_B))
        pred_fake_B = netD_12(fake_B.detach())
        loss_D_fake_B = criterion_GAN(pred_fake_B, torch.zeros_like(pred_fake_B))
        loss_D_B = 0.5 * (loss_D_real_B + loss_D_fake_B)
        loss_D_B.backward()
        optimizer_D_B.step()

        # Accumulate per-sample sums so the epoch average is exact even when
        # the last batch is smaller.
        batch_size = inputs.size(0)
        gen_loss_total += loss_G.item() * batch_size
        disc0_loss_total += loss_D_A.item() * batch_size
        disc12_loss_total += loss_D_B.item() * batch_size

    n_samples = len(trainingloader.dataset)
    gen_loss_total /= n_samples
    disc0_loss_total /= n_samples
    disc12_loss_total /= n_samples

    print(f"Epoch [{epoch+1}/{num_epochs}], Generator Loss: {gen_loss_total:.4f}, DiscriminatorA Loss: {disc0_loss_total:.4f}, DiscriminatorB Loss: {disc12_loss_total:.4f}")

# Persist the trained weights to the paths the fine-tuning section loads from,
# so the notebook runs top-to-bottom on a fresh kernel.
checkpoint_dir = Path("./saved_model/cycleGAN")
checkpoint_dir.mkdir(parents=True, exist_ok=True)
torch.save(netG_0_12.state_dict(), checkpoint_dir / "netG_0_12.pth")
torch.save(netG_12_0.state_dict(), checkpoint_dir / "netG_12_0.pth")
torch.save(netD_0.state_dict(), checkpoint_dir / "netD_0.pth")
torch.save(netD_12.state_dict(), checkpoint_dir / "netD_12.pth")

In [None]:
# Frozen, pre-trained 3-D DenseNet for the fine-tuning stage.
# NOTE(review): `densenet` is not referenced anywhere below -- the fine-tuning
# "feature loss" is a plain voxel-wise L1 loss. Confirm whether a feature
# (perceptual) loss through this network was intended.
from DenseNet import DenseNet3D  # explicit import instead of `import *`

densenet = DenseNet3D()
# map_location remaps tensors saved on another device (e.g. a GPU) onto the
# device chosen at the top of the notebook, so CPU-only hosts can load it.
densenet.load_state_dict(torch.load('./saved_model/DenseNet/model400.pth', map_location=device))
densenet = densenet.to(device)
# Inference mode for a frozen network: BatchNorm/Dropout layers (if any) stop
# using batch statistics / random masks.
densenet.eval()
densenet.requires_grad_(False)


In [None]:
# Fresh networks for the fine-tuning stage; their random initial weights are
# overwritten by the checkpoints loaded in the next cell. Build order matches
# the original (G_0_12, G_12_0, D_0, D_12).
netG_0_12, netG_12_0 = (UNet3D(1, 1).to(device) for _ in range(2))
netD_0, netD_12 = (Discriminator(1).to(device) for _ in range(2))

In [None]:
# Restore the CycleGAN weights from ./saved_model/cycleGAN/.
# map_location=device remaps tensors saved on another device (e.g. a GPU) onto
# the device chosen at the top of the notebook, so this also works on CPU-only
# hosts instead of failing inside torch.load.
netG_0_12.load_state_dict(torch.load('./saved_model/cycleGAN/netG_0_12.pth', map_location=device))
netG_12_0.load_state_dict(torch.load('./saved_model/cycleGAN/netG_12_0.pth', map_location=device))
netD_0.load_state_dict(torch.load('./saved_model/cycleGAN/netD_0.pth', map_location=device))
netD_12.load_state_dict(torch.load('./saved_model/cycleGAN/netD_12.pth', map_location=device))

In [None]:
# Fine-tuning data: same folders as stage 1. end_index=576 is forwarded to the
# project-local `dataset` class -- presumably it caps how many samples are
# used; confirm in dataloader.py.
Trainingset = dataset(
    file_path1="./reg_data/00/",
    file_path2="./reg_data/12/",
    force=0,
    end_index=576,
)
trainingloader = DataLoader(
    dataset=Trainingset,
    batch_size=4,
    shuffle=True,
)


In [None]:
# Losses for the fine-tuning stage.
criterion_GAN = nn.BCELoss()    # adversarial loss on sigmoid probabilities
criterion_cycle = nn.L1Loss()   # cycle-consistency and paired L1 terms
# NOTE(review): criterion_MSE is never used by the fine-tuning loop.
criterion_MSE = nn.MSELoss()

# Fresh optimizers with the same settings as stage 1: one Adam for both
# generators, and one per discriminator with a 10x lower learning rate.
_adam_betas = (0.5, 0.999)
optimizer_G = optim.Adam(
    itertools.chain(netG_0_12.parameters(), netG_12_0.parameters()),
    lr=1e-6,
    betas=_adam_betas,
)
optimizer_D_A = optim.Adam(netD_0.parameters(), lr=1e-7, betas=_adam_betas)
optimizer_D_B = optim.Adam(netD_12.parameters(), lr=1e-7, betas=_adam_betas)

In [None]:
num_epochs: int = 400  # full passes over `trainingloader` during fine-tuning

In [None]:
# Fine-tuning: the stage-1 CycleGAN updates plus a weighted paired L1 term that
# treats each (inputs, targets) batch as a voxel-aligned pair.
PAIRED_L1_WEIGHT = 25  # weight of each paired L1 term in the generator loss

for epoch in range(num_epochs):
    gen_loss_total = 0.0
    disc0_loss_total = 0.0
    disc12_loss_total = 0.0
    for inputs, targets, _force in tqdm(trainingloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # Add a channel axis and cast to float32 for the 3-D convolutions.
        real_A = inputs.to(device).unsqueeze(1).float()
        real_B = targets.to(device).unsqueeze(1).float()

        # ---- Generator update ------------------------------------------------
        # Freeze the discriminators so loss_G.backward() does not compute
        # gradients for their weights (they would be discarded anyway).
        netD_0.requires_grad_(False)
        netD_12.requires_grad_(False)
        optimizer_G.zero_grad()

        fake_B = netG_0_12(real_A)   # A -> B
        recov_A = netG_12_0(fake_B)  # A -> B -> A
        fake_A = netG_12_0(real_B)   # B -> A
        recov_B = netG_0_12(fake_A)  # B -> A -> B

        # Adversarial losses: each generator wants the discriminator of its
        # output domain to label the translation as real (target 1).
        pred_fake_A = netD_0(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))
        pred_fake_B = netD_12(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B))

        # Cycle-consistency losses: both round trips should reproduce the input.
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)

        # Paired (supervised) L1 loss: each translation is compared directly
        # with its counterpart from the same batch. Despite the `fea` names this
        # is not a feature-space loss.
        # NOTE(review): the frozen `densenet` and `criterion_MSE` are never
        # used -- confirm whether a DenseNet feature loss was intended here.
        loss_fea_A = criterion_cycle(fake_A, real_A)
        loss_fea_B = criterion_cycle(fake_B, real_B)

        loss_G = (
            loss_GAN_B2A + loss_GAN_A2B + loss_cycle_A + loss_cycle_B
            + PAIRED_L1_WEIGHT * loss_fea_A + PAIRED_L1_WEIGHT * loss_fea_B
        )
        loss_G.backward()
        optimizer_G.step()

        # ---- Discriminator updates -------------------------------------------
        netD_0.requires_grad_(True)
        netD_12.requires_grad_(True)

        # D_0 (domain A): real_A -> 1, fake_A -> 0. detach() stops these losses
        # from back-propagating into the generators.
        optimizer_D_A.zero_grad()
        pred_real_A = netD_0(real_A)
        loss_D_real_A = criterion_GAN(pred_real_A, torch.ones_like(pred_real_A))
        pred_fake_A = netD_0(fake_A.detach())
        loss_D_fake_A = criterion_GAN(pred_fake_A, torch.zeros_like(pred_fake_A))
        loss_D_A = 0.5 * (loss_D_real_A + loss_D_fake_A)
        loss_D_A.backward()
        optimizer_D_A.step()

        # D_12 (domain B): real_B -> 1, fake_B -> 0.
        optimizer_D_B.zero_grad()
        pred_real_B = netD_12(real_B)
        loss_D_real_B = criterion_GAN(pred_real_B, torch.ones_like(pred_real_B))
        pred_fake_B = netD_12(fake_B.detach())
        loss_D_fake_B = criterion_GAN(pred_fake_B, torch.zeros_like(pred_fake_B))
        loss_D_B = 0.5 * (loss_D_real_B + loss_D_fake_B)
        loss_D_B.backward()
        optimizer_D_B.step()

        # Per-sample sums so the epoch average is exact for a short last batch.
        batch_size = inputs.size(0)
        gen_loss_total += loss_G.item() * batch_size
        disc0_loss_total += loss_D_A.item() * batch_size
        disc12_loss_total += loss_D_B.item() * batch_size

    n_samples = len(trainingloader.dataset)
    gen_loss_total /= n_samples
    disc0_loss_total /= n_samples
    disc12_loss_total /= n_samples

    print(f"Epoch [{epoch+1}/{num_epochs}], Generator Loss: {gen_loss_total:.4f}, DiscriminatorA Loss: {disc0_loss_total:.4f}, DiscriminatorB Loss: {disc12_loss_total:.4f}")

# TODO(review): the fine-tuned weights are never saved; add torch.save(...)
# calls once a destination (distinct from the stage-1 checkpoints) is chosen.