In [12]:
import os
import random
from sklearn.model_selection import StratifiedKFold
from collections import defaultdict, Counter

In [13]:
master_file = os.path.join(os.getcwd(), 'annotations', 'aidan_allclips_annotations.txt')
output_dir = os.path.join(os.getcwd(), 'annotations', '4fold_patient_balance')

In [14]:
def read_master_file(file_path):
    clips = []
    labels = []
    with open(file_path, 'r') as f:
        for line in f:
            clip, label = line.strip().split()
            clips.append(clip)
            labels.append(int(label))
    return clips, labels

def write_annotation_file(file_path, clips, labels):
    with open(file_path, 'w') as f:
        for clip, label in zip(clips, labels):
            f.write(f"{clip} {label}\n")

def balance_data(clips, labels):
    positive = [(c, l) for c, l in zip(clips, labels) if l == 1]
    negative = [(c, l) for c, l in zip(clips, labels) if l == 0]
    
    n_samples = min(len(positive), len(negative))
    balanced = random.sample(positive, n_samples) + random.sample(negative, n_samples)
    random.shuffle(balanced)
    return zip(*balanced)

def group_by_patient(clips, labels):
    patient_groups = defaultdict(lambda: {'clips': [], 'labels': []})
    for clip, label in zip(clips, labels):
        patient = clip.split('/')[1].split('_')[0]  # Extract patient number
        patient_groups[patient]['clips'].append(clip)
        patient_groups[patient]['labels'].append(label)
    return patient_groups

def create_cross_validation_files(master_file, output_dir):
    clips, labels = read_master_file(master_file)
    patient_groups = group_by_patient(clips, labels)

    num_folds = 4
    patients = list(patient_groups.keys())
    random.shuffle(patients)

    for fold in range(num_folds):
        fold_dir = os.path.join(output_dir, f"fold_{fold+1}")
        os.makedirs(fold_dir, exist_ok=True)

        # Split patients into train, val, and test
        n_patients = len(patients)
        test_patients = patients[int(0.8 * n_patients):]
        val_patients = patients[int(0.64 * n_patients):int(0.8 * n_patients)]
        train_patients = patients[:int(0.64 * n_patients)]

        # Collect clips and labels for each set
        train_clips, train_labels = [], []
        val_clips, val_labels = [], []
        test_clips, test_labels = [], []

        for patient in train_patients:
            train_clips.extend(patient_groups[patient]['clips'])
            train_labels.extend(patient_groups[patient]['labels'])
        
        for patient in val_patients:
            val_clips.extend(patient_groups[patient]['clips'])
            val_labels.extend(patient_groups[patient]['labels'])
        
        for patient in test_patients:
            test_clips.extend(patient_groups[patient]['clips'])
            test_labels.extend(patient_groups[patient]['labels'])

        # Balance data for each set
        train_clips, train_labels = balance_data(train_clips, train_labels)
        val_clips, val_labels = balance_data(val_clips, val_labels)
        test_clips, test_labels = balance_data(test_clips, test_labels)

        # Write annotation files
        write_annotation_file(os.path.join(fold_dir, "train.txt"), train_clips, train_labels)
        write_annotation_file(os.path.join(fold_dir, "val.txt"), val_clips, val_labels)
        write_annotation_file(os.path.join(fold_dir, "test.txt"), test_clips, test_labels)

        # Rotate patients for next fold
        patients = patients[n_patients//num_folds:] + patients[:n_patients//num_folds]

In [15]:
create_cross_validation_files(master_file, output_dir)

#### Ensure the files were created properly

In [20]:
def analyze_cross_validation(base_dir):
    for fold in range(1, 5):  # 4 folds
        fold_dir = os.path.join(base_dir, f"fold_{fold}")
        print(f"\nAnalyzing Fold {fold}:")
        
        set_counts = {}
        set_patients = {}
        set_labels = {}
        total_videos = 0
        
        for set_name in ['train', 'val', 'test']:
            file_path = os.path.join(fold_dir, f"{set_name}.txt")
            
            videos = []
            patients = set()
            labels = []
            
            with open(file_path, 'r') as f:
                for line in f:
                    video_path, label = line.strip().split()
                    videos.append(video_path)
                    labels.append(int(label))
                    patient = video_path.split('/')[1].split('_')[0]  # Extract patient number
                    patients.add(patient)
            
            set_counts[set_name] = len(videos)
            set_patients[set_name] = patients
            set_labels[set_name] = labels
            total_videos += len(videos)
        
        # Print counts, percentages, and label ratios
        print(f"Total videos in fold: {total_videos}")
        for set_name in ['train', 'val', 'test']:
            count = set_counts[set_name]
            percentage = (count / total_videos) * 100
            negative_count = set_labels[set_name].count(0)
            positive_count = set_labels[set_name].count(1)
            negative_ratio = (negative_count / count) * 100
            positive_ratio = (positive_count / count) * 100
            
            print(f"{set_name.capitalize()} set: {count} videos ({percentage:.2f}%)")
            print(f"  Negative (0) clips: {negative_count} ({negative_ratio:.2f}%)")
            print(f"  Positive (1) clips: {positive_count} ({positive_ratio:.2f}%)")
        
        # Print patients in each set
        for set_name in ['train', 'val', 'test']:
            print(f"\nPatients in {set_name} set: {', '.join(sorted(set_patients[set_name]))}")
        
        # Check for patient overlap
        all_patients = set()
        for patients in set_patients.values():
            all_patients.update(patients)
        
        patient_count = defaultdict(int)
        for patients in set_patients.values():
            for patient in patients:
                patient_count[patient] += 1
        
        overlapping_patients = [patient for patient, count in patient_count.items() if count > 1]
        
        if overlapping_patients:
            print("\nWARNING: The following patients appear in multiple sets:")
            for patient in overlapping_patients:
                print(f"Patient {patient} appears in:")
                for set_name, patients in set_patients.items():
                    if patient in patients:
                        print(f"  - {set_name} set")
        else:
            print("\nVerification successful: No patients appear in multiple sets within this fold.")

In [21]:
analyze_cross_validation(output_dir)


Analyzing Fold 1:
Total videos in fold: 2132
Train set: 1482 videos (69.51%)
  Negative (0) clips: 741 (50.00%)
  Positive (1) clips: 741 (50.00%)
Val set: 274 videos (12.85%)
  Negative (0) clips: 137 (50.00%)
  Positive (1) clips: 137 (50.00%)
Test set: 376 videos (17.64%)
  Negative (0) clips: 188 (50.00%)
  Positive (1) clips: 188 (50.00%)

Patients in train set: 02267738, 02268547, 05323733, 05352576, 05418761, 05447543, 05454991, 05463487, 05467817, 05486196, 05489744, 05497695, 05501184, 06338772, 06348578, 06394294

Patients in val set: 05109836, 05512494, 05513119, 06452950

Patients in test set: 00582992, 00913367, 05235825, 05514820, 06381028

Verification successful: No patients appear in multiple sets within this fold.

Analyzing Fold 2:
Total videos in fold: 2132
Train set: 1148 videos (53.85%)
  Negative (0) clips: 574 (50.00%)
  Positive (1) clips: 574 (50.00%)
Val set: 302 videos (14.17%)
  Negative (0) clips: 151 (50.00%)
  Positive (1) clips: 151 (50.00%)
Test set: 

### Write individual video annotation files for testing

In [19]:
from collections import defaultdict

# Function to read the annotations file and group entries
def group_annotations(file_path):
    groups = defaultdict(list)
    with open(file_path, 'r') as file:
        for line in file:
            parts = line.split()
            video_path = parts[0]
            label = parts[1]
            # Extract the XXXXXXXX identifier from the file name
            identifier = video_path.split('_')[2]
            groups[identifier].append(line.strip())
    return groups

# Function to write grouped entries to separate files
def write_grouped_annotations(groups, output_dir):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    for identifier, lines in groups.items():
        output_file_path = os.path.join(output_dir, f'{identifier}.txt')
        with open(output_file_path, 'w+') as output_file:
            for line in lines:
                output_file.write(line + '\n')

# Main function
def write_test_annotations(input_file, output_dir):
    groups = group_annotations(input_file)
    write_grouped_annotations(groups, output_dir)
    print(f"Annotations have been successfully grouped and written to '{output_dir}'.")


In [20]:
input_file = os.path.join(os.getcwd(), "annotations", "aidan_allclips_annotations.txt")  # Path to the input annotations file
output_dir = os.path.join(os.getcwd(), "annotations", "video_test_annotations")  # Directory to store the grouped files
write_test_annotations(input_file, output_dir)

Annotations have been successfully grouped and written to 'c:\Users\u251245\CVEpilepsy_remote\annotations\video_test_annotations'.
