# Model application
This notebook will perform the following tasks:
1. Split the movies into training set, validation set and test set.
2. Apply the classification model to the movie dataset.

# Part 1: Train-validation-test split

In [1]:
import os
import pandas as pd
import pickle as pkl

IMG_DIR = "/home/xavier/Documents/dataset/Welch/trainingset2/trainingset2"
OUT_DIR = "/home/xavier/Documents/dataset/Welch/classification-v2024/classification_models/240430-001/movie_classification"

os.makedirs(OUT_DIR, exist_ok=True)

classify_df, name_df = [], []
phenotype_dict = {}
'''Build dataset'''
for strain in os.listdir(IMG_DIR):
    run_id = int(strain[-4:])
    for scope in os.listdir(os.path.join(IMG_DIR, strain)):
        scope_id = int(scope[-2:])
        directory = f"{strain}/{scope}"
        name_df.append((run_id, scope_id, directory))

name_df = pd.DataFrame(name_df, columns=['run_id', 'scope_id', 'directory'])
name_df

Unnamed: 0,run_id,scope_id,directory
0,636,38,CS5_78_0425_1%agar_Run0636/Scope38
1,636,37,CS5_78_0425_1%agar_Run0636/Scope37
2,636,39,CS5_78_0425_1%agar_Run0636/Scope39
3,672,3,CS6_27_1253_1%agar_Run0672/Scope03
4,672,2,CS6_27_1253_1%agar_Run0672/Scope02
...,...,...,...
932,287,31,CS1_55_10536_1%agar_Run0287/Scope31
933,287,32,CS1_55_10536_1%agar_Run0287/Scope32
934,526,30,CS4_81_5257_1%agar_Run0526/Scope30
935,526,29,CS4_81_5257_1%agar_Run0526/Scope29


## Split the dataset
We choose 100 movies each for validation set and test set. The movies have distinct run_id in each set.

In [2]:
VALIDATION_SIZE = 100

unique_run_ids = name_df['run_id'].drop_duplicates().sample(n=VALIDATION_SIZE, random_state=70)
test_set = name_df[name_df['run_id'].isin(unique_run_ids)].groupby('run_id').apply(
    lambda x: x.sample(1, random_state=405)).reset_index(drop=True)

mask = pd.merge(name_df, test_set, on=['run_id', 'scope_id'], how='left', indicator=True)
remaining_set = mask[mask['_merge'] == 'left_only'].drop(columns=['_merge'])

unique_run_ids = remaining_set['run_id'].drop_duplicates().sample(n=VALIDATION_SIZE, random_state=44)
validation_set = remaining_set[remaining_set['run_id'].isin(unique_run_ids)].groupby('run_id').apply(
    lambda x: x.sample(1, random_state=1622)).reset_index(drop=True)

mask = pd.merge(remaining_set, validation_set, on=['run_id', 'scope_id'], how='left', indicator=True)
training_set = mask[mask['_merge'] == 'left_only'].drop(columns=['_merge'])

test_set['Category'] = 'Test'
validation_set['Category'] = 'Validation'
training_set['Category'] = 'Training'
full_dataset = pd.concat([test_set, validation_set, training_set])
full_dataset = full_dataset.drop_duplicates(subset=['run_id', 'scope_id'])
full_dataset = full_dataset[['run_id', 'scope_id', 'Category']]
full_dataset = pd.merge(full_dataset, name_df, on=['run_id', 'scope_id'], how='left')
full_dataset.to_csv(f'{OUT_DIR}/generator_full_dataset.csv', index=False)

full_dataset

In [3]:
ans = 0
for index, row in full_dataset.iterrows():
    ans += len(os.listdir(os.path.join(IMG_DIR, row['directory'])))
print("There are " + str(ans) + " samples in total.")

# Part 2: Classify images
We use the pre-trained inception-V3 network and center crop of images to perform classification.

In [21]:
import cv2
import os
from tqdm import tqdm
import numpy as np
import pandas as pd
import dnnlib
import pickle as pkl
import torch
from torch.utils.data import DataLoader, Dataset


def resize_crop(img_dir, resize_by=1., resolution=512, brightness_norm=False, brightness_mean=107.2, use_rgb=True):
    img = cv2.imread(img_dir, cv2.IMREAD_UNCHANGED)
    if img.dtype != np.uint8:
        img = np.uint8(img / 256)
    img_shape = img.shape
    resize_shape = np.array([img_shape[1] * resize_by, img_shape[0] * resize_by], dtype=int)
    if resize_by != 1:
        img = cv2.resize(img, resize_shape, cv2.INTER_LANCZOS4)

    if use_rgb and len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    new_img = img[(resize_shape[1] - resolution) // 2:(resize_shape[1] + resolution) // 2,
              (resize_shape[0] - resolution) // 2:(resize_shape[0] + resolution) // 2]
    if brightness_norm:
        obj_v = np.mean(new_img)
        value = brightness_mean - obj_v
        new_img = cv2.add(new_img, value)
    return new_img


class TestDataset(Dataset):
    def __init__(self, root_dir, img_names):
        self.root_dir = root_dir
        self.img_names = img_names

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img = resize_crop(os.path.join(self.root_dir, self.img_names[idx]))
        img = torch.tensor(img[:, :, :]).permute(2, 0, 1)
        return img


MODEL_DIR = "/home/xavier/Documents/dataset/Welch/classification-v2024/classification_models/240430-001"
OUT_DIR = "/home/xavier/Documents/dataset/Welch/classification-v2024/classification_models/240430-001/movie_classification"
best_epoch = 560
BATCH_SIZE = 500
TEST_WORKERS = 8
checkpoint_interval = 10
device = torch.device("cuda")

In [5]:
TEMP_DIR = f'{OUT_DIR}/temporary_pred_labels.npy'
with dnnlib.util.open_url(os.path.join(MODEL_DIR, f"model_epoch_{best_epoch:03d}.pkl"), verbose=True) as f:
    detector = pkl.load(f)
detector.to(device)
detector.eval()

all_img_names = []
categories = []
for index, row in full_dataset.iterrows():
    img_folder = f"{IMG_DIR}/{row['directory']}"
    img_names = [f"{row['directory']}/{img_name}" for img_name in os.listdir(img_folder)]
    all_img_names.extend(img_names)
    categories.extend([row['Category']] * len(img_names))

test_dataset = TestDataset(root_dir=IMG_DIR, img_names=all_img_names)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=TEST_WORKERS)

if os.path.exists(TEMP_DIR):
    pred_labels = list(np.load(TEMP_DIR))
    start_batch = len(pred_labels) // BATCH_SIZE
else:
    pred_labels = []
    start_batch = 0

with torch.no_grad():
    for i, images in enumerate(tqdm(test_dataloader, desc="Processing images")):
        if i < start_batch:
            continue
        images = images.to(device)
        outputs = detector(images)
        _, predicted = torch.max(outputs, 1)
        pred_labels.extend(predicted.cpu().numpy())

        if (i + 1) % checkpoint_interval == 0:
            np.save(TEMP_DIR, np.array(pred_labels))

final_labels = np.array(pred_labels)

final_df = pd.DataFrame({
    'img_name': all_img_names,
    'label': final_labels,
    'Category': categories
})

final_df.to_csv(f'{OUT_DIR}/classified_images.csv', index=False)

In [22]:
final_df = pd.read_csv(f"{OUT_DIR}/classified_images.csv")
label_dict = pkl.load(open(f"{MODEL_DIR}/label_dict.pkl", "rb"))
inv_label_dict = {v: k for k, v in label_dict.items()}
# final_df['label'] = final_df['label'].map(inv_label_dict)
label_counts = final_df.groupby(['label', 'Category']).size().unstack(fill_value=0)

total_row = label_counts[['Training', 'Validation', 'Test']].sum().astype(int)
total_row.name = 'Total'

total_df = pd.DataFrame(total_row).transpose()
label_counts = pd.concat([label_counts, total_df])
label_counts['Total'] = label_counts[['Training', 'Validation', 'Test']].sum(axis=1).astype(int)
label_counts.index = label_counts.index.map(inv_label_dict)
label_counts

Category,Test,Training,Validation,Total
Blank,18192,104810,11990,134992
Branched,1097,8868,586,10551
Clusters,5771,30835,6884,43490
Dense,7627,60905,7682,76214
Incomplete,23979,181305,21813,227097
LWT,7762,61657,11113,80532
Large,2965,15941,3318,22224
Long,3201,36308,5868,45377
Malformed,6780,48926,7497,63203
Small,8282,45295,5756,59333


# Check sample images

In [8]:
import matplotlib.pyplot as plt
from PIL import Image

# Select one image name for each label
sample_images = final_df.groupby('label').first()['img_name']

# Setup the plot
fig, axes = plt.subplots(nrows=1, ncols=len(sample_images), figsize=(15, 5))

# Check if there is only one label, and handle the axis accordingly
if len(sample_images) == 1:
    axes = [axes]

# Load and display each image
for ax, (label, img_path) in zip(axes, sample_images.items()):
    img = Image.open(os.path.join(IMG_DIR, img_path))
    ax.imshow(img, cmap="gray")
    ax.set_title(label)
    ax.axis('off')  # Hide the axes ticks

plt.tight_layout()
plt.show()

# Prepare dictionary for training

In [24]:
training_dict = {}
for phenotype, group in final_df[final_df['Category'] == 'Training'].groupby('label'):
    names = group['img_name'].values
    training_dict[phenotype] = [f"{name}" for name in names]
pkl.dump(training_dict, open(os.path.join(OUT_DIR, "training_dict.pkl"), "wb"))

# training_dict_1class = {0: []}
# for phenotype, group in final_df[final_df['Category'] == 'Training'].groupby('label'):
#     names = group['img_name'].values
#     training_dict_1class[0].extend([f"{name}" for name in names])
# pkl.dump(training_dict_1class, open(os.path.join(OUT_DIR, "training_dict_1class.pkl"), "wb"))

validation_set = {}
for phenotype, group in final_df[final_df['Category'] == 'Validation'].groupby('label'):
    names = group['img_name'].values
    validation_set[phenotype] = [f"{name}" for name in names]
pkl.dump(validation_set, open(os.path.join(OUT_DIR, "validation_dict.pkl"), "wb"))

test_dict = {}
for phenotype, group in final_df[final_df['Category'] == 'Test'].groupby('label'):
    names = group['img_name'].values
    test_dict[phenotype] = [f"{name}" for name in names]
pkl.dump(test_dict, open(os.path.join(OUT_DIR, "test_dict.pkl"), "wb"))

# Extra
## Sample 100 images in each case to form a smaller validation set

In [2]:
import random

validation_set = pkl.load(open(os.path.join(OUT_DIR, "validation_dict.pkl"), "rb"))
for item in validation_set:
    validation_set[item] = random.sample(validation_set[item], 5)

pkl.dump(validation_set, open(os.path.join(OUT_DIR, "validation_tiny_dict.pkl"), "wb"))

## Merge validation set and test set

In [3]:
validation_set = pkl.load(open(os.path.join(OUT_DIR, "validation_dict.pkl"), "rb"))
test_set = pkl.load(open(os.path.join(OUT_DIR, "test_dict.pkl"), "rb"))
for phenotype in validation_set:
    validation_set[phenotype].extend(test_set[phenotype])
pkl.dump(validation_set, open(os.path.join(OUT_DIR, "merged_test_dict.pkl"), "wb"))

## Collect Images to a new folder or .zip

In [3]:
IMG_DIR = "/home/xavier/Documents/dataset/Welch/trainingset2/trainingset2"
OUT_DIR = "/media/xavier/Storage/feature_extraction/Welch-validation.zip"
DICT_DIR = "/home/xavier/Documents/dataset/Welch/classification-v2024/classification_models/240430-001/movie_classification/validation_dict.pkl"

import os
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import psutil
import time
import pickle as pkl


# Function to monitor memory and pause if threshold is exceeded
def monitor_memory(threshold_mb=20971):  # Set threshold to 30 GB
    while True:
        usage_mb = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        if usage_mb < threshold_mb:
            return
        print(f"Current memory usage is {usage_mb:.2f} MB, pausing until memory drops below threshold...")
        time.sleep(10)


# Function to process a single file
def process_file(file_path, base_folder, zipf, lock):
    arcname = file_path[len(base_folder.rstrip(os.sep)) + 1:]
    with lock:
        zipf.write(file_path, arcname=arcname)


# Function to process files in batches with indexed progress bars
def process_files_in_batches(files, base_folder, zipf, lock, batch_size=2000):
    # Process files in batches
    total_batches = (len(files) + batch_size - 1) // batch_size
    for i in range(0, len(files), batch_size):
        batch_index = i // batch_size + 1
        batch = files[i:i + batch_size]
        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = {executor.submit(process_file, file, base_folder, zipf, lock): file for file in batch}
            with tqdm(as_completed(futures), total=len(futures),
                      desc=f"Batch {batch_index}/{total_batches} - Zipping files", miniters=1) as progress:
                for future in progress:
                    future.result()  # Wait for the future to complete
        monitor_memory()  # Check memory after each batch


# Main function to zip directory
def zip_directory(base_folder, dictionary_dir, zip_filename):
    files_to_zip = []
    sampled_dict = pkl.load(open(dictionary_dir, "rb"))
    for group in sampled_dict:
        files_to_zip.extend([os.path.join(base_folder, img_name) for img_name in sampled_dict[group]])

    with zipfile.ZipFile(zip_filename, 'w', compression=zipfile.ZIP_STORED) as zipf:
        from threading import Lock
        lock = Lock()  # Lock for thread-safe writing to the zipfile
        process_files_in_batches(files_to_zip, base_folder, zipf, lock)


zip_directory(IMG_DIR, DICT_DIR, OUT_DIR)

Batch 1/73 - Zipping files: 100%|██████████| 2000/2000 [00:01<00:00, 1386.38it/s]
Batch 2/73 - Zipping files: 100%|██████████| 2000/2000 [00:01<00:00, 1409.37it/s]
Batch 3/73 - Zipping files: 100%|██████████| 2000/2000 [00:01<00:00, 1213.96it/s]
Batch 4/73 - Zipping files: 100%|██████████| 2000/2000 [00:01<00:00, 1010.65it/s]
Batch 5/73 - Zipping files: 100%|██████████| 2000/2000 [00:01<00:00, 1277.49it/s]
Batch 6/73 - Zipping files: 100%|██████████| 2000/2000 [00:02<00:00, 997.87it/s] 
Batch 7/73 - Zipping files: 100%|██████████| 2000/2000 [00:02<00:00, 681.37it/s]
Batch 8/73 - Zipping files: 100%|██████████| 2000/2000 [00:02<00:00, 825.51it/s] 
Batch 9/73 - Zipping files: 100%|██████████| 2000/2000 [00:02<00:00, 844.57it/s] 
Batch 10/73 - Zipping files: 100%|██████████| 2000/2000 [00:02<00:00, 746.62it/s] 
Batch 11/73 - Zipping files: 100%|██████████| 2000/2000 [00:03<00:00, 515.49it/s]
Batch 12/73 - Zipping files: 100%|██████████| 2000/2000 [00:03<00:00, 575.19it/s]
Batch 13/73 - Zi