In [7]:
import torch 
import torch.nn as nn 
import torch.multiprocessing as mp

import os
import time
import argparse
import shutil
import model_class
import prepare_data_nrrd_class
from config_class import Config
import preprocessor as Preprocess
from synthesis_class import Generate_sCT


In [1]:
# Delete every regular file that sits directly inside each patient folder,
# keeping only the sub-folders.
# DESTRUCTIVE: files are removed permanently. Set DRY_RUN = True to only list them.
import os
datapath = r'/data/galaponav/dataset/newHN_CBCT_test'
DRY_RUN = False
patient_list = sorted(os.listdir(datapath))


for patient in patient_list:
    patient_dir = os.path.join(datapath, patient)
    # A stray file at the dataset root is not a patient folder; calling
    # os.listdir on it would raise NotADirectoryError, so skip it.
    if not os.path.isdir(patient_dir):
        continue
    # Sorted so the deletion order and the printed log are deterministic.
    for entry in sorted(os.listdir(patient_dir)):
        entry_path = os.path.join(patient_dir, entry)
        # os.path.isdir follows symlinks: links to directories are kept,
        # links to files (and broken links) are removed, not their targets.
        if os.path.isdir(entry_path):
            continue
        if DRY_RUN:
            print(f'Would delete: {entry_path}')
        else:
            os.remove(entry_path)
            print(f'Deleted: {entry_path}')

Deleted: /data/galaponav/dataset/newHN_CBCT_test/p0002/sCT_cycleGAN_p0002_TL30_FX.nrrd
Deleted: /data/galaponav/dataset/newHN_CBCT_test/p0002/CBCT_registered.nrrd
Deleted: /data/galaponav/dataset/newHN_CBCT_test/p0002/sCT_cycleGAN_p0002_TL20_FT_masked.nrrd
Deleted: /data/galaponav/dataset/newHN_CBCT_test/p0002/sCT_DCNN_p0002_TL35_FX.nrrd
Deleted: /data/galaponav/dataset/newHN_CBCT_test/p0002/sCT_cycleGAN_p0002_TL25_FX_masked.nrrd
Deleted: /data/galaponav/dataset/newHN_CBCT_test/p0002/mask.nrrd
Deleted: /data/galaponav/dataset/newHN_CBCT_test/p0002/sCT_cycleGAN_p0002_TL35_FT.nrrd
Deleted: /data/galaponav/dataset/newHN_CBCT_test/p0002/sCT_cycleGAN_p0002_TL10_FT_masked.nrrd
Deleted: /data/galaponav/dataset/newHN_CBCT_test/p0002/sCT_DCNN_p0002_TL10_FX.nrrd
Deleted: /data/galaponav/dataset/newHN_CBCT_test/p0002/sCT_cycleGAN_p0002_TL10_FX_masked.nrrd
Deleted: /data/galaponav/dataset/newHN_CBCT_test/p0002/sCT_cycleGAN_p0002_TL35_FX.nrrd
Deleted: /data/galaponav/dataset/newHN_CBCT_test/p0002/s

In [8]:
json_path = r"/home/galaponav/art/scripts/PhD/SCT_toolbox/tp_test_mp.json"
# Alternative weights previously used in the JSON config:
# "weights_path": "/data/galaponav/output/newHN_CBCT/DCNN/weights/epoch_18_UNET_forCap_best.pth",

time_start = time.time()
cfg = Config(json_path)

# Prepare and preprocess data.
# NOTE(review): the last recorded run of this cell raised
# "IndexError: index 70 is out of bounds for axis 0 with size 70" after
# "Deleting existing preprocessed data..." was printed. The slice arguments
# (10, 10) passed to run_sitk are worth checking against the volume depth — TODO confirm.
prepare_data = prepare_data_nrrd_class.prepare_dataset(cfg)
prepare_data.run_sitk(10, 10, DIR=cfg.DIR, eval=cfg.EVAL)
Preprocess.Preprocessor(cfg).preprocess()

# Initialize model and dataloader.
model, device = model_class.Model(cfg).initialize_models()
# NOTE(review): a second Model(cfg) wrapper is built only to call enable_dropout,
# and num_inference is not used further in this cell.
num_inference = model_class.Model(cfg).enable_dropout(model)
dataloader = prepare_data.create_dataset()

# Multiprocessing
print('Start Multiprocessing')
num_processes = 10
# Move the parameters into shared memory so all workers use the same tensors.
model.share_memory()

# NOTE(review): every worker receives identical (cfg, model, dataloader, device).
# Unless Generate_sCT splits the data internally, all workers process the same
# inputs — TODO confirm.
# NOTE(review): if `device` is CUDA, PyTorch requires the 'spawn' or 'forkserver'
# start method for subprocesses; the Linux default is 'fork'.
processes = []
for rank in range(num_processes):
    sct_gen = Generate_sCT(cfg, model, dataloader, device)
    # Pass the bound method itself. Calling it here (inference_loop()) would run
    # the whole loop in the notebook process and give Process its return value.
    p = mp.Process(target=sct_gen.inference_loop, name=f"process_{rank}")
    p.start()
    processes.append(p)
    print(f"Process {rank} started")

# Wait for every worker, so the timing below covers the actual inference.
for p in processes:
    p.join()

time_end = time.time()
print(f"Time taken: {time_end - time_start}")

# Surface worker crashes instead of silently finishing the cell.
failed = [p.name for p in processes if p.exitcode != 0]
if failed:
    raise RuntimeError(f"Worker process(es) exited with a non-zero code: {failed}")

Preparing images...
Removing top and bottom slices...
Segmenting Mask...
Preprocessing...
Deleting existing preprocessed data...


IndexError: index 70 is out of bounds for axis 0 with size 70

In [None]:
def train(model, X, Y, learning_rate=0.01, n_iters=100, log_every=10):
    """Fit ``model`` to (X, Y) with full-batch SGD on a mean-squared-error loss.

    Intended as a ``torch.multiprocessing`` target. When several processes train
    the same model whose parameters were placed in shared memory, each process
    runs its own optimizer and updates the shared parameters concurrently
    (Hogwild-style).

    Args:
        model: ``nn.Module`` mapping X to predictions shaped like Y.
        X: input tensor.
        Y: target tensor.
        learning_rate: SGD step size.
        n_iters: number of gradient steps.
        log_every: print the loss every ``log_every`` steps; 0 disables logging.
    """
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    for epoch in range(n_iters):
        # Clear gradients from the previous step before computing new ones.
        optimizer.zero_grad()
        loss = loss_fn(model(X), Y)
        loss.backward()
        optimizer.step()
        if log_every and epoch % log_every == 0:
            print(f'PID {os.getpid()}: epoch {epoch + 1}: loss = {loss.item():.3f}')


# Main function
# NOTE(review): with the 'spawn' start method (macOS/Windows default), children
# cannot import a `train` defined inside a notebook; this demo relies on 'fork'.
if __name__ == '__main__':
    # Seed so the initial weights (and the printed predictions) are reproducible.
    torch.manual_seed(0)

    # Toy regression data following y = 2x.
    X = torch.tensor([[1], [2], [3], [4]], dtype=torch.float32)
    Y = torch.tensor([[2], [4], [6], [8]], dtype=torch.float32)
    n_samples, n_features = X.shape
    print(f'#samples: {n_samples}, #features: {n_features}')

    # Test input and model input/output sizes.
    X_test = torch.tensor([5], dtype=torch.float32)
    input_size = n_features
    output_size = n_features

    # Linear model and its prediction before training.
    model = nn.Linear(input_size, output_size)
    print(f'Prediction before training: f(5) = {model(X_test).item():.3f}')

    num_processes = 4
    # Move the parameters into shared memory so every process updates the same tensors.
    model.share_memory()

    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=train, args=(model, X, Y), name=f'Process-{rank}')
        p.start()
        processes.append(p)
        print(f'Started {p.name}')

    # Wait for all processes to finish.
    for p in processes:
        p.join()
        print(f'Finished {p.name}')

    print(f'Prediction after training: f(5) = {model(X_test).item():.3f}')
