In [None]:
import os

import torch
import cv2
import torch.nn as nn
import numpy as np
from torchsummary import summary
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from tqdm import tqdm
from dataloader import dataset
from matplotlib import pyplot as plt



In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
from pytorch3dunet.unet3d import *
from pytorch3dunet.unet3d.buildingblocks import *
from pytorch3dunet.unet3d.model import *


In [None]:
# Single-channel input and output; every other UNet3D option is left at the library default.
u3d = UNet3D(
    in_channels=1,
    out_channels=1,
)

In [None]:
# Paired data directories passed to `dataset` as file_path1 / file_path2.
DATA_DIR_1 = "./reg_data/00/"
DATA_DIR_2 = "./reg_data/04/"
# Indices [0, SPLIT_INDEX) are held out for validation, [SPLIT_INDEX, END_INDEX) are used
# for training (end-exclusivity is presumed from how the ranges abut — confirm in `dataset`).
SPLIT_INDEX = 144
END_INDEX = 720
BATCH_SIZE = 8


def _make_split(start_index, end_index):
    """Build a `dataset` over the shared data directories for one index range."""
    return dataset(file_path1=DATA_DIR_1, file_path2=DATA_DIR_2, force=0,
                   start_index=start_index, end_index=end_index)


Trainingset = _make_split(SPLIT_INDEX, END_INDEX)
trainingloader = DataLoader(dataset=Trainingset, batch_size=BATCH_SIZE, shuffle=True)

Testingset = _make_split(0, SPLIT_INDEX)
# Validation loss is a sample-weighted average, so batch order is irrelevant; keep it
# deterministic instead of shuffling.
testloader = DataLoader(dataset=Testingset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
def eval_Unet(model, val_loader, criterion, device=None):
    """Return the sample-weighted mean of ``criterion`` over ``val_loader``.

    The model is switched to eval mode and evaluated under ``torch.no_grad()``;
    the caller is responsible for switching it back to train mode afterwards.

    Args:
        model: network to evaluate.
        val_loader: DataLoader yielding ``(inputs, targets, _)`` batches.
        criterion: loss function, assumed to average over the batch (e.g.
            ``nn.MSELoss()`` with the default ``reduction='mean'``), so each batch
            loss is re-weighted by its batch size.
        device: device to move batches to. Defaults to the device of the model's
            first parameter (CPU if the model has none), so batches always land
            where the model lives rather than on a notebook-global device.

    Returns:
        float: mean loss per sample, or ``nan`` if the loader's dataset is empty.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for inputs, targets, _ in val_loader:
            # Insert a singleton channel axis at dim 1; presumably raw batches are
            # (N, D, H, W) — confirm against the dataset.
            inputs = inputs.to(device).unsqueeze(1).float()
            targets = targets.to(device).unsqueeze(1).float()

            loss = criterion(model(inputs), targets)
            loss_sum += loss.item() * inputs.size(0)

    n_samples = len(val_loader.dataset)
    if n_samples == 0:
        # Avoid ZeroDivisionError; nan never compares below a best-so-far loss.
        return float("nan")
    return loss_sum / n_samples

In [None]:
def _train_one_epoch(model, loader, optimizer, criterion, device, desc):
    """Run one optimisation pass over ``loader`` and return the sample-weighted mean loss."""
    model.train()
    loss_sum = 0.0
    for inputs, targets, _ in tqdm(loader, desc=desc):
        # Insert a singleton channel axis at dim 1; presumably raw batches are
        # (N, D, H, W) — confirm against the dataset.
        inputs = inputs.to(device).unsqueeze(1).float()
        targets = targets.to(device).unsqueeze(1).float()

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # criterion averages over the batch, so re-weight by batch size before
        # taking the dataset-level mean.
        loss_sum += loss.item() * inputs.size(0)

    n_samples = len(loader.dataset)
    return loss_sum / n_samples if n_samples else float("nan")


def train_Unet(model, train_loader, val_loader, num_epochs, device, lr=0.0002, beta1=0.5, beta2=0.999,
               save_threshold=0.07,
               checkpoint_dir="./saved_model/old_unet",
               final_path="./saved_model/Unet/04_final_lr1e-4.pth"):
    """Train ``model`` with Adam and MSE loss, validating after every epoch.

    After each epoch the validation loss is computed with ``eval_Unet``. Whenever it
    is a new minimum and also below ``save_threshold``, the weights are saved into
    ``checkpoint_dir``. The final weights are always saved to ``final_path``.

    Args:
        model: network to train; moved to ``device`` in place.
        train_loader: DataLoader yielding ``(inputs, targets, _)`` batches.
        val_loader: DataLoader with the same batch layout, used for validation.
        num_epochs: number of passes over ``train_loader``.
        device: device to train on.
        lr, beta1, beta2: Adam learning rate and betas.
        save_threshold: only validation losses strictly below this are checkpointed.
        checkpoint_dir: directory for best-so-far checkpoints (created if missing).
        final_path: file for the final weights (parent directory created if missing).

    Returns:
        tuple[list[float], list[float]]: per-epoch training and validation losses.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(beta1, beta2))
    criterion = nn.MSELoss()

    # Create output directories up front: a missing folder should fail immediately,
    # not inside torch.save after the entire training run.
    os.makedirs(checkpoint_dir, exist_ok=True)
    final_dir = os.path.dirname(final_path)
    if final_dir:
        os.makedirs(final_dir, exist_ok=True)

    model.to(device)
    train_losses = []
    val_losses = []
    best_val = float("inf")
    for epoch in range(num_epochs):
        train_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device,
                                      desc=f"Epoch {epoch+1}/{num_epochs}")
        train_losses.append(train_loss)

        val_loss = eval_Unet(model, val_loader, criterion)
        val_losses.append(val_loss)

        print(f"Epoch [{epoch+1}/{num_epochs}], Training_Loss: {train_loss:.4f}, Val_Loss: {val_loss:.4f}")

        if val_loss < best_val:
            best_val = val_loss
            if val_loss < save_threshold:
                # NOTE(review): the 'lr1e-4' suffix is hard-coded and does not track
                # `lr`; `epoch` here is 0-based while the log line above is 1-based.
                ckpt_name = f'04_{epoch}_{val_loss:.4f}_lr1e-4.pth'
                torch.save(model.state_dict(), os.path.join(checkpoint_dir, ckpt_name))

    torch.save(model.state_dict(), final_path)
    print('Finished Training')

    return train_losses, val_losses


In [None]:
# 400 epochs of Adam (lr=1e-4, betas=(0.5, 0.999)); keeps the per-epoch loss histories.
Training_losses, Val_losses = train_Unet(
    u3d,
    trainingloader,
    testloader,
    num_epochs=400,
    device=device,
    lr=1e-4,
    beta1=0.5,
    beta2=0.999,
)