In [None]:
# Description: Build train/ and test/ folders, each split into class subfolders,
# as input for training deep-learning classifiers.
from spacr.io import generate_training_dataset

settings = dict(
    src='path or list of paths',
    dataset_mode='metadata',
    test_split=0.1,
    metadata_type_by='col',
    class_metadata=[['c1'], ['c2']],
    png_type='cell_png',
    nuclei_limit=True,
    pathogen_limit=3,
    uninfected=False,
    size=None,
)

generate_training_dataset(settings)

In [None]:
# Description: Train a torch classifier on single-object images.
from spacr.deep_spacr import train_test_model

settings = dict(
    src='path',
    train=True,
    test=False,
    custom_model=False,
    custom_model_path=None,
    classes=['nc', 'pc'],
    model_type='maxvit_t',
    optimizer_type='adamw',
    schedule='reduce_lr_on_plateau',  # options: reduce_lr_on_plateau, step_lr
    loss_type='focal_loss',  # options: binary_cross_entropy_with_logits, focal_loss
    normalize=True,
    image_size=224,
    batch_size=64,
    epochs=100,
    val_split=0.1,
    learning_rate=0.0001,
    weight_decay=0.00001,
    dropout_rate=0.1,
    init_weights=True,
    amsgrad=True,
    use_checkpoint=True,
    gradient_accumulation=True,
    gradient_accumulation_steps=4,
    # NOTE(review): key spelling kept verbatim; presumably this is the exact key
    # spacr looks up, so verify against spacr before "correcting" it.
    intermedeate_save=True,
    pin_memory=True,
    n_jobs=30,
    train_channels=['r', 'g', 'b'],
    augment=False,
    verbose=True,
)

train_test_model(settings)

In [None]:
# Description: Pack single-object images into a tar archive.
from spacr.io import generate_dataset

settings = dict(
    src='path or list of paths',
    file_metadata=None,
    experiment='tsg101_screen_plate1',
    sample=None,
)

generate_dataset(settings)

In [None]:
# Description: Classify images in a tar dataset with a trained torch model.
from spacr.core import apply_model_to_tar

settings = {'tar_path':'path',
            'model_path':'path',
            'file_type':'cell_png',
            'image_size':224,
            'batch_size':64,
            'normalize':True,
            'score_threshold':0.5,
            'n_jobs':30,
            'verbose':True}

# Call the imported function directly: the import above binds only
# `apply_model_to_tar`, not the `spacr` package, so `spacr.core.apply_model_to_tar`
# would raise NameError on a fresh kernel.
result_df = apply_model_to_tar(settings)

In [None]:
# Description: Fix a regression model to estimate the effect size of gRNAs on cell scores.
# 

from spacr.ml import perform_regression
import pandas as pd
%matplotlib inline

settings = {'count_data':'path',
            'score_data':'path',
            'highlight':'string',
            'fraction_threshold':0.1,
            'dependent_variable': 'prediction_probability_class_1',
            'transform':'log',
            'agg_type':'median',
            'min_cell_count':25,
            'regression_type':'ols',
            'random_row_column_effects':False,
            'plate':None,
            'cov_type':None,
            'alpha':0.8,
            'nc':'c1',
            'pc':'c2',
            'other':'c3'}

coef_df = perform_regression(settings)

In [None]:
import torch

def model_fusion(model_paths, save_path, model_class, device='cpu'):
    """
    Fuse an arbitrary number of models by averaging their weights, then save the result.

    Floating-point state_dict entries (weights, biases, running statistics) are averaged
    element-wise, accumulating in at least float32 and casting back to the fused model's
    own dtype. Non-floating entries (integer/bool buffers such as a step counter) cannot be
    meaningfully averaged, so they are copied from the first model.

    Parameters:
        model_paths (list): File paths to the models to be fused (e.g. PyTorch .pth files).
            Each file may hold a state_dict or a pickled ``torch.nn.Module``.
        save_path (str): File path the fused model's state_dict is written to.
        model_class (class): Zero-argument callable that builds the target architecture
            (e.g. your MaxViT class).
        device (str): Device to load the models onto ('cpu' or 'cuda').

    Returns:
        torch.nn.Module: A ``model_class`` instance holding the averaged weights.

    Raises:
        ValueError: If ``model_paths`` is empty, or the loaded state_dicts do not all have
            the same keys.
    """
    model_paths = list(model_paths)
    if not model_paths:
        raise ValueError("model_paths must contain at least one model file.")

    # Load all model weights onto the requested device.
    state_dicts = []
    for path in model_paths:
        print(f"Loading model from: {path}")
        loaded = torch.load(path, map_location=device)
        # Accept a whole pickled module as well as a bare state_dict.
        if isinstance(loaded, torch.nn.Module):
            loaded = loaded.state_dict()
        state_dicts.append(loaded)

    # Ensure all models have the same architecture (identical state_dict keys).
    reference_keys = state_dicts[0].keys()
    if not all(sd.keys() == reference_keys for sd in state_dicts):
        raise ValueError("All models must have the same architecture and state_dict keys.")

    # Fresh model whose state_dict receives the averaged values.
    fused_model = model_class().to(device)
    fused_state_dict = fused_model.state_dict()

    for key in list(fused_state_dict.keys()):
        target = fused_state_dict[key]
        if torch.is_floating_point(target):
            # Average in at least float32 (keeps float64 precision, avoids half-precision
            # round-off), then cast back to the dtype the fused model actually uses.
            acc_dtype = torch.promote_types(target.dtype, torch.float32)
            stacked = torch.stack([sd[key].to(acc_dtype) for sd in state_dicts])
            fused_state_dict[key] = stacked.mean(dim=0).to(target.dtype)
        else:
            # Averaging integer/bool buffers yields a fraction that load_state_dict would
            # silently truncate; take the first model's value instead.
            fused_state_dict[key] = state_dicts[0][key]

    fused_model.load_state_dict(fused_state_dict)

    torch.save(fused_model.state_dict(), save_path)
    print(f"Fused model saved to: {save_path}")

    return fused_model

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

def train_student_model(teacher_models, student_model, dataloader, optimizer, loss_fn, device='cpu'):
    """
    Train a student model for one pass over ``dataloader`` by distilling from teacher models.

    For each batch, the teachers' softmax outputs are averaged into soft labels. The student
    is then updated with ``loss_fn(student_outputs, soft_labels)``. Any ground-truth labels
    in the batch are ignored.

    Parameters:
        teacher_models (list): Pre-trained teacher models. They are moved to ``device`` and
            put in eval mode. They are run without autograd, so they are never modified and
            never accumulate gradients.
        student_model (torch.nn.Module): The student model to be trained (updated in place).
        dataloader (iterable): Yields either tuples/lists whose first element is the image
            tensor (e.g. ``(images, labels)``) or bare image tensors.
        optimizer (torch.optim.Optimizer): Optimizer over the student model's parameters.
        loss_fn (callable): Called as ``loss_fn(student_outputs, soft_labels)``. It receives
            the student's raw outputs and the averaged teacher probabilities. NOTE: losses
            such as ``nn.KLDivLoss`` expect log-probabilities as input; wrap them accordingly.
        device (str): Device to use ('cpu' or 'cuda').

    Returns:
        torch.nn.Module: The same ``student_model`` object, after training.

    Raises:
        ValueError: If ``teacher_models`` is empty.
    """
    teacher_models = list(teacher_models)
    if not teacher_models:
        raise ValueError("teacher_models must contain at least one model.")

    student_model.to(device)
    for teacher in teacher_models:
        teacher.to(device)
        teacher.eval()

    student_model.train()
    for batch in dataloader:
        # Only the images are needed for distillation; labels (if any) are unused.
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        images = images.to(device)

        # Teachers only provide targets. Without no_grad, loss.backward() would also flow
        # into the teacher parameters, accumulating .grad there every batch (the student's
        # optimizer never zeroes them) and holding the teacher graphs in memory.
        with torch.no_grad():
            teacher_probs = [F.softmax(teacher(images), dim=1) for teacher in teacher_models]
            soft_labels = torch.stack(teacher_probs).mean(dim=0)

        student_outputs = student_model(images)
        loss = loss_fn(student_outputs, soft_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return student_model