# Data set generation

In [None]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import pandas as pd
import cv2
import os
import torch
import torch.nn.functional as F

In [None]:
# Read the table of available ultrasound probands, drop the 'Unnamed: 0'
# column (presumably a leftover index column -- TODO confirm) and keep only
# the rows whose anatomy is not labelled as background.
available = (
    pd.read_csv("/home/nicke/MasterThesis/available_US_probands.csv")
    .drop(columns='Unnamed: 0')
    .loc[lambda table: table['Anatomy'] != 'BACKGROUND']
)
available

In [None]:
# Root directory of the recordings; get_image_seg_pairs expects one
# sub-directory per proband ID below it.
path_to_data = '/share/data_ultraschall/compressions'

In [None]:
# IDs of all available probands as strings (re-assigned further down once
# the landmark table has been merged in).
id_list = available['Id'].values.astype(str)

In [None]:
# Read the landmark annotations of all IDs.
landmarks = pd.read_csv('/home/nicke/MasterThesis/landmarks.csv')

# Keep only the rows that carry usable compression start and end frames:
# neither list may be empty and the end frames must not be marked 'DNC'
# (meaning of 'DNC' is not visible here -- TODO confirm).
start_annotated = landmarks['Start Frames'] != '[]'
end_annotated = landmarks['End Frames'] != '[]'
end_usable = landmarks['End Frames'] != 'DNC'
landmarks = landmarks[start_annotated & end_annotated & end_usable]
landmarks

In [None]:
# Join the filtered landmarks with the available probands on the shared
# 'Id' column (pandas' default inner join: only IDs present in both remain).
data = pd.merge(landmarks, available, on='Id')

In [None]:
# IDs to process, as strings. The merge repeats an Id whenever one of the
# joined tables lists it more than once; de-duplicate (keeping the order of
# first appearance) so that no proband is loaded several times and ends up
# with identical pairs in the data set.
id_list = data['Id'].drop_duplicates().values.astype(str)

In [None]:
def load_image_and_seg(path, image):
    """
    function to load one frame and the segmentation stored under the same name

    path: str. path to an ID directory; the frame is read from
          <path>/frames/<image>, the segmentation from
          <path>/segmentations/1/<image>
    image: str. file name of the image to be loaded

    return: tuple (img, seg) of np.ndarray. img is the frame divided by 255,
            seg is the segmentation divided by 200
    """

    frame_file = os.path.join(path, 'frames', image)
    seg_file = os.path.join(path, 'segmentations', '1', image)

    # Open both files in a context manager so the file handles are released
    # right away instead of whenever the garbage collector gets to them.
    # np.array() forces the pixel data to be read before the file is closed.
    with Image.open(frame_file) as frame:
        # scale to [0, 1] -- assumes 8-bit pixel values, TODO confirm
        img = np.array(frame) / 255

    with Image.open(seg_file) as mask:
        # normalize over the labels -- presumably 200 is the largest label
        # value used in the segmentation files, TODO confirm
        seg = np.array(mask) / 200

    return img, seg

def get_image_seg_pairs(prob_id, frame_step=2, max_pairs=7):
    """
    function to load (fixed, moving) image pairs from one ID

    The fixed frame is the start frame of the first annotated compression of
    the ID. The moving frames follow it at distances of frame_step,
    2 * frame_step, ... and stay strictly before the end frame of that
    compression. Pairs in which either segmentation is completely empty are
    dropped.

    prob_id: str denoting the ID of a folder to be loaded
    frame_step: int >= 1. distance in frames between consecutive moving
                frames (default 2)
    max_pairs: int. upper limit of pairs selected before the empty
               segmentations are filtered out (default 7)
    return: np.ndarray image pairs ; np.ndarray segmentation pairs.
            Both arrays are empty if no usable pair exists.
    raises: ValueError if frame_step is smaller than 1
    """
    if frame_step < 1:
        raise ValueError(f'frame_step must be >= 1, got {frame_step}')

    print(f'Working on id: {prob_id}')

    id_dir = os.path.join(path_to_data, prob_id)

    # generate sorted list of image names.
    # common structure: 00xx.png where XX is a number between 00 and the max number of recorded frames
    frame_files = sorted(os.listdir(os.path.join(id_dir, 'frames')))

    # look the landmark row up once; the frame lists are stored as text such as '[12, 80]'
    id_rows = landmarks[landmarks['Id'] == int(prob_id)]
    start_frames = np.fromstring(id_rows['Start Frames'].iat[0].strip(']['), sep=',', dtype=int)
    end_frames = np.fromstring(id_rows['End Frames'].iat[0].strip(']['), sep=',', dtype=int)

    if start_frames.size == 0 or end_frames.size == 0:
        # nothing annotated for this ID
        return np.array([]), np.array([])

    # Only the first compression is used: one fixed frame per ID. Both values
    # are plain ints, so there is nothing to enumerate over.
    first_available_frame = int(start_frames[0])
    # never index past the frames that actually exist on disk
    last_available_frame = min(int(end_frames[0]), len(frame_files))

    # generate image pairs by names
    file_pairs = []
    for moving_idx in range(first_available_frame + frame_step, last_available_frame, frame_step):
        if len(file_pairs) >= max_pairs:
            break
        file_pairs.append([frame_files[first_available_frame], frame_files[moving_idx]])

    frame_pairs = []
    seg_pairs = []

    # load the seg and frame for fixed and moving
    for fixed_file, moving_file in file_pairs:
        fixed, fixed_seg = load_image_and_seg(id_dir, fixed_file)
        moving, moving_seg = load_image_and_seg(id_dir, moving_file)

        # a pair is only useful if both frames contain a segmented structure
        if fixed_seg.max() == 0 or moving_seg.max() == 0:
            continue

        # and store them together
        frame_pairs.append([fixed, moving])
        seg_pairs.append([fixed_seg, moving_seg])

    return np.array(frame_pairs), np.array(seg_pairs)
    

In [None]:
# Walk over every usable ID and gather its frame and segmentation pairs in
# two flat lists (one entry per pair).
frames, segs, ids = [], [], []
for prob_id in id_list:
    # frame and segmentation pairs of this ID
    frame_pairs, seg_pairs = get_image_seg_pairs(prob_id)

    # iterating the returned arrays yields one pair at a time
    frames.extend(frame_pairs)
    segs.extend(seg_pairs)

In [None]:
# Stack the collected pairs into one array each and hand them to torch
# (from_numpy shares the memory with the NumPy array).
stacked_frames = np.asarray(frames)
stacked_segs = np.asarray(segs)
all_frames = torch.from_numpy(stacked_frames)
all_segs = torch.from_numpy(stacked_segs)

In [None]:
# Sanity check: the frame tensor and the segmentation tensor must have
# exactly the same shape.
assert all_frames.shape == all_segs.shape

In [None]:
# Display the shape of the stacked frame pairs.
all_frames.shape

In [None]:
# split into train and eval
# A dedicated, seeded generator makes the 90/10 split reproducible between
# runs; the unseeded global NumPy RNG produced a different split every time.
split_rng = np.random.default_rng(42)
n_samples = len(all_frames)
train_idx = split_rng.choice(n_samples, size=int(n_samples * 0.9), replace=False)

# every index that was not drawn for training, in ascending order
test_idx = np.setdiff1d(np.arange(n_samples), train_idx)

In [None]:
# Display the indices drawn for the training split.
train_idx

In [None]:
# Show how many samples ended up in the training and in the test split.
print(train_idx.shape)
test_idx.shape

In [None]:
# NOTE: the split indices are drawn with NumPy before this cell, so this
# torch seed has no influence on which samples land in which split.
torch.manual_seed(42)

# index tensors for both splits
train_selector = torch.from_numpy(train_idx)
test_selector = torch.from_numpy(test_idx)

# training data
frames, segs = all_frames[train_selector], all_segs[train_selector]

# held-out data
test_frames, test_segs = all_frames[test_selector], all_segs[test_selector]

### Store Torch data

In [None]:
# Write the held-out frames and segmentations to disk.
for tensor, target in (
    (test_frames, "/share/data_ultraschall/nicke_ma/data/test_frames_oneFixed_multipleMoving_dist2.pth"),
    (test_segs, "/share/data_ultraschall/nicke_ma/data/test_segs_oneFixed_multipleMoving_dist2.pth"),
):
    torch.save(tensor, target)

In [None]:
# Write the training frames and segmentations to disk.
for tensor, target in (
    (frames, "/share/data_ultraschall/nicke_ma/data/frames_oneFixed_multipleMoving_dist2.pth"),
    (segs, "/share/data_ultraschall/nicke_ma/data/segs_oneFixed_multipleMoving_dist2.pth"),
):
    torch.save(tensor, target)