In [11]:
import torch
import torch.nn.functional as F
import torch.optim as optim

def model_knowledge_transfer(settings):
    """
    High-level wrapper that:
      1) Generates a training DataLoader from 'src' using generate_loaders.
      2) Performs knowledge distillation from multiple teacher models onto a 
         single student, using the new labeled data for both supervised CE 
         and distillation loss.
      3) Saves the distilled student to `student_save_path`.

    Args:
        teacher_paths (list[str]): Paths to teacher checkpoints (TorchModel or dict).
        src (str): Source directory for your new labeled data (with subfolders 'train'/'test').
        student_save_path (str): Where to save the final student TorchModel.
        device (str): 'cpu' or 'cuda'.
        student_model_name, pretrained, dropout_rate, use_checkpoint: 
            TorchModel init arguments for the student.
        alpha, temperature, lr, epochs: Distillation hyperparams.
        image_size, batch_size, classes, etc.: Passed to generate_loaders.
    """

    from spacr.io import generate_loaders
    
    def _knowledge_transfer(
        teacher_paths,
        student_save_path,
        data_loader,            # DataLoader for (images, labels)
        device='cpu',
        student_model_name='maxvit_t',
        pretrained=True,
        dropout_rate=None,
        use_checkpoint=False,
        alpha=0.5,
        temperature=2.0,
        lr=1e-4,
        epochs=10
    ):
        """
        Performs multi-teacher knowledge distillation on a new labeled dataset,
        producing a single student TorchModel that combines the teachers' knowledge
        plus the labeled data.

        Args:
            teacher_paths (list[str]): Paths to teacher models (TorchModel or dict).
            student_save_path (str): Destination path to save the final student.
            data_loader (DataLoader): Yields (images, labels) from the new dataset.
            device (str): 'cpu' or 'cuda'.
            student_model_name (str): Architecture name for the student TorchModel.
            pretrained (bool): If the student should be initialized as pretrained.
            dropout_rate (float): If needed by your TorchModel init.
            use_checkpoint (bool): If needed by your TorchModel init.
            alpha (float): Weight for real-label CE vs. distillation loss.
            temperature (float): Distillation temperature.
            lr (float): Learning rate for the student.
            epochs (int): Number of training epochs.

        Returns:
            TorchModel: The final, trained student model.
        """
        from spacr.utils import TorchModel  # adjust if needed

        # Adjust filename to reflect KD if desired
        import os
        base, ext = os.path.splitext(student_save_path)
        if not ext:
            ext = '.pth'
        student_save_path = f"{base}_KD{ext}"

        # 1) Load teacher models
        teachers = []
        print("Loading teacher models:")
        for path in teacher_paths:
            print(f"  Loading teacher: {path}")
            ckpt = torch.load(path, map_location=device)
            if isinstance(ckpt, TorchModel):
                teacher = ckpt.to(device)
            elif isinstance(ckpt, dict):
                from spacr.utils import TorchModel
                teacher = TorchModel(
                    model_name=ckpt.get('model_name', student_model_name),
                    pretrained=ckpt.get('pretrained', pretrained),
                    dropout_rate=ckpt.get('dropout_rate', dropout_rate),
                    use_checkpoint=ckpt.get('use_checkpoint', use_checkpoint)
                ).to(device)
                teacher.load_state_dict(ckpt['model'])
            else:
                raise ValueError(f"Unsupported checkpoint type at {path} (must be TorchModel or dict).")

            teacher.eval()  # for consistent BN, dropout
            teachers.append(teacher)

        # 2) Initialize the student TorchModel
        student_model = TorchModel(
            model_name=student_model_name,
            pretrained=pretrained,
            dropout_rate=dropout_rate,
            use_checkpoint=use_checkpoint
        ).to(device)

        # 3) Optimizer
        optimizer = optim.Adam(student_model.parameters(), lr=lr)

        # 4) Training loop
        for epoch in range(epochs):
            student_model.train()
            running_loss = 0.0

            for images, labels in data_loader:
                images = images.to(device)
                labels = labels.to(device)

                # Forward pass: student
                logits_s = student_model(images)
                logits_s_temp = logits_s / temperature

                # Distillation: get average teacher probabilities
                with torch.no_grad():
                    teacher_probs_list = []
                    for tm in teachers:
                        logits_t = tm(images) / temperature
                        teacher_probs_list.append(F.softmax(logits_t, dim=1))
                    teacher_probs_ensemble = torch.mean(torch.stack(teacher_probs_list), dim=0)

                # Student distribution
                student_log_probs = F.log_softmax(logits_s_temp, dim=1)

                # Distillation loss (KLDiv)
                loss_distill = F.kl_div(
                    student_log_probs,
                    teacher_probs_ensemble,
                    reduction='batchmean'
                ) * (temperature ** 2)

                # Supervised loss (CE with real labels)
                loss_ce = F.cross_entropy(logits_s, labels)

                # Weighted sum
                loss = alpha * loss_ce + (1 - alpha) * loss_distill

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            avg_loss = running_loss / len(data_loader)
            print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f}")

        # 5) Save final student
        torch.save(student_model, student_save_path)
        print(f"Knowledge-distilled student saved to: {student_save_path}")

        return student_model
    
    # 1) Generate DataLoader
    print("Generating training DataLoader...")
    train_loaders, val_loaders, train_fig = generate_loaders(
        src=src,
        mode='train',
        image_size=image_size,
        batch_size=batch_size,
        classes=classes,
        n_jobs=n_jobs,
        validation_split=validation_split,
        pin_memory=pin_memory,
        normalize=normalize,
        channels=channels,
        augment=augment,
        verbose=verbose
    )

    # If validation_split=0, train_loaders is a single DataLoader
    # If >0, it's a DataLoader for train, and val_loaders is for validation
    if validation_split > 0.0:
        print("Note: We'll only use the train DataLoader for knowledge distillation, ignoring val_loaders.")
        train_loader = train_loaders
    else:
        train_loader = train_loaders  # or whichever you used

    # 2) Perform knowledge distillation
    #from .my_code import model_knowledge_transfer  # or your actual import
    distilled_student = _knowledge_transfer(
        teacher_paths=teacher_paths,
        student_save_path=student_save_path,
        data_loader=train_loader,
        device=device,
        student_model_name=student_model_name,
        pretrained=pretrained,
        dropout_rate=dropout_rate,
        use_checkpoint=use_checkpoint,
        alpha=alpha,
        temperature=temperature,
        lr=lr,
        epochs=epochs
    )

    print("Distillation complete. Student model returned.")
    return distilled_student, train_fig

teacher_paths = ['/nas_mnt/carruthers/Einar/tsg101_screen/TSG101SCREEN_20240810_132824/plate1/datasets/training/model/maxvit_t/rgb/epochs_100/maxvit_t_epoch_100_channels_rgb.pth',
              '/nas_mnt/carruthers/Einar/tsg101_screen/TSG101SCREEN_20240824_072829/plate2/datasets/training/model/maxvit_t/rgb/epochs_100/maxvit_t_epoch_100_channels_rgb.pth',
              '/nas_mnt/carruthers/Einar/tsg101_screen/TSG101SCREEN_20240825_094106/plate3/datasets/training/model/maxvit_t/rgb/epochs_100/maxvit_t_epoch_100_channels_rgb.pth',
              '/nas_mnt/carruthers/Einar/tsg101_screen/TSG101SCREEN_20240826_140251/plate4/datasets/training/model/maxvit_t/rgb/epochs_100/maxvit_t_epoch_100_channels_rgb.pth']

src = '/nas_mnt/carruthers/Einar/tsg101_screen/hits_20250108_181547/plate2/datasets/training'

settings = {
    'teacher_paths':['list of paths'],
    'src':'path',
    'student_save_path':'save_path',
    'device':'cpu',
    'student_model_name':'maxvit_t',
    'pretrained':True,
    'dropout_rate':None,
    'use_checkpoint':False,
    'alpha':0.5,
    'temperature':2.0,
    'lr':1e-4,
    'epochs':10,
    'image_size':224,
    'batch_size':32,
    'classes':('nc','pc'),
    'n_jobs':None,
    'validation_split':0.0,
    'pin_memory':False,
    'normalize':False,
    'channels':('r','g','b'),
    'augment':False,
    'verbose':False}

model_knowledge_transfer(settings)

In [16]:
import os
import torch
import torch.nn.functional as F
import torch.optim as optim

def model_knowledge_transfer(settings):
    """
    High-level wrapper that:
      1) Generates a training DataLoader from 'src' using generate_loaders.
      2) Performs knowledge distillation from multiple teacher models onto a 
         single student TorchModel, using the new labeled data for both 
         supervised CE and distillation loss.
      3) Saves the distilled student to `student_save_path`.

    Args:
        settings (dict): A dictionary containing all the necessary parameters:
            {
              "teacher_paths": list of str,  # Paths to teacher checkpoints
              "src": str,                    # Source directory for new labeled data
              "student_save_path": str,      # Where to save the final student model
              "device": str,                 # 'cpu' or 'cuda'
              "student_model_name": str,     # e.g. 'maxvit_t'
              "pretrained": bool,
              "dropout_rate": float or None,
              "use_checkpoint": bool,
              "alpha": float,                # Weight for real-label CE vs. distillation
              "temperature": float,          # Distillation temperature
              "lr": float,                   # Learning rate
              "epochs": int,                 # Training epochs
              "image_size": int,
              "batch_size": int,
              "classes": tuple,              # e.g. ('nc','pc')
              "n_jobs": int or None,
              "validation_split": float,
              "pin_memory": bool,
              "normalize": bool,
              "channels": tuple,             # e.g. ('r','g','b')
              "augment": bool,
              "verbose": bool
            }

    Returns:
        (student_model, train_fig):
            student_model: The final, trained (distilled) TorchModel
            train_fig:     A figure from generate_loaders() if any (or None)
    """

    # Extract arguments from the settings dict, with defaults or fallback
    teacher_paths      = settings.get('teacher_paths', [])
    src                = settings.get('src', '')
    student_save_path  = settings.get('student_save_path', 'distilled_student.pth')
    device             = settings.get('device', 'cpu')
    student_model_name = settings.get('student_model_name', 'maxvit_t')
    pretrained         = settings.get('pretrained', True)
    dropout_rate       = settings.get('dropout_rate', None)
    use_checkpoint     = settings.get('use_checkpoint', False)
    alpha              = settings.get('alpha', 0.5)
    temperature        = settings.get('temperature', 2.0)
    lr                 = settings.get('lr', 1e-4)
    epochs             = settings.get('epochs', 10)
    image_size         = settings.get('image_size', 224)
    batch_size         = settings.get('batch_size', 32)
    classes            = settings.get('classes', ('nc', 'pc'))
    n_jobs             = settings.get('n_jobs', None)
    validation_split   = settings.get('validation_split', 0.0)
    pin_memory         = settings.get('pin_memory', False)
    normalize          = settings.get('normalize', False)
    channels           = settings.get('channels', ('r','g','b'))
    augment            = settings.get('augment', False)
    verbose            = settings.get('verbose', False)

    # -- 1) generate_loaders (returns train_loaders, val_loaders, train_fig) --
    from spacr.io import generate_loaders
    print("Generating training DataLoader(s) from:", src)
    train_loaders, val_loaders, train_fig = generate_loaders(
        src=src,
        mode='train',
        image_size=image_size,
        batch_size=batch_size,
        classes=classes,
        n_jobs=n_jobs,
        validation_split=validation_split,
        pin_memory=pin_memory,
        normalize=normalize,
        channels=channels,
        augment=augment,
        verbose=verbose
    )

    # If validation_split > 0, train_loaders is the train set, val_loaders is val set
    # Otherwise train_loaders is a single DataLoader.
    if validation_split > 0.0:
        print("Note: We'll only use the train_loader for knowledge distillation, ignoring val_loader.")
        train_loader = train_loaders
    else:
        train_loader = train_loaders

    # -- 2) define the internal knowledge-distillation function --
    def _knowledge_transfer(
        teacher_paths,
        student_save_path,
        data_loader,
        device='cpu',
        student_model_name='maxvit_t',
        pretrained=True,
        dropout_rate=None,
        use_checkpoint=False,
        alpha=0.5,
        temperature=2.0,
        lr=1e-4,
        epochs=10
    ):
        """
        Performs multi-teacher knowledge distillation on a new labeled dataset,
        producing a single student TorchModel that combines the teachers' knowledge
        plus the labeled data.
        """
        from spacr.utils import TorchModel  # or wherever TorchModel is located

        # Adjust filename to reflect KD if desired
        base, ext = os.path.splitext(student_save_path)
        if not ext:
            ext = '.pth'
        student_save_path = f"{base}_KD{ext}"

        # 1) Load teacher models
        teachers = []
        print("Loading teacher models:")
        for path in teacher_paths:
            print(f"  Loading teacher: {path}")
            ckpt = torch.load(path, map_location=device)
            if isinstance(ckpt, TorchModel):
                teacher = ckpt.to(device)
            elif isinstance(ckpt, dict):
                # create a new TorchModel with possible metadata
                teacher = TorchModel(
                    model_name=ckpt.get('model_name', student_model_name),
                    pretrained=ckpt.get('pretrained', pretrained),
                    dropout_rate=ckpt.get('dropout_rate', dropout_rate),
                    use_checkpoint=ckpt.get('use_checkpoint', use_checkpoint)
                ).to(device)
                teacher.load_state_dict(ckpt['model'])
            else:
                raise ValueError(f"Unsupported checkpoint type at {path} (must be TorchModel or dict).")

            teacher.eval()  # freeze teacher in eval mode
            teachers.append(teacher)

        # 2) Initialize the student TorchModel
        student_model = TorchModel(
            model_name=student_model_name,
            pretrained=pretrained,
            dropout_rate=dropout_rate,
            use_checkpoint=use_checkpoint
        ).to(device)

        # 3) Setup optimizer
        optimizer = optim.Adam(student_model.parameters(), lr=lr)

        # 4) Distillation training loop
        for epoch in range(epochs):
            student_model.train()
            running_loss = 0.0

            for images, labels in data_loader:
                images = images.to(device)
                labels = labels.to(device)

                # Forward pass (student)
                logits_s = student_model(images)
                logits_s_temp = logits_s / temperature

                # Teacher ensemble output
                with torch.no_grad():
                    teacher_probs_list = []
                    for tm in teachers:
                        logits_t = tm(images) / temperature
                        teacher_probs_list.append(F.softmax(logits_t, dim=1))
                    teacher_probs_ensemble = torch.mean(torch.stack(teacher_probs_list), dim=0)

                student_log_probs = F.log_softmax(logits_s_temp, dim=1)

                # Distillation loss
                loss_distill = F.kl_div(
                    student_log_probs,
                    teacher_probs_ensemble,
                    reduction='batchmean'
                ) * (temperature ** 2)

                # Supervised CE with ground-truth labels
                loss_ce = F.cross_entropy(logits_s, labels)

                # Weighted total loss
                loss = alpha * loss_ce + (1 - alpha) * loss_distill

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()

            avg_loss = running_loss / len(data_loader)
            print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f}")

        # 5) Save the final student
        torch.save(student_model, student_save_path)
        print(f"Knowledge-distilled student saved to: {student_save_path}")

        return student_model

    # -- 3) Perform knowledge distillation using the internal function --
    distilled_student = _knowledge_transfer(
        teacher_paths=teacher_paths,
        student_save_path=student_save_path,
        data_loader=train_loader,
        device=device,
        student_model_name=student_model_name,
        pretrained=pretrained,
        dropout_rate=dropout_rate,
        use_checkpoint=use_checkpoint,
        alpha=alpha,
        temperature=temperature,
        lr=lr,
        epochs=epochs
    )

    print("Distillation complete. Student model returned.")
    return distilled_student, train_fig


teacher_paths = ['/nas_mnt/carruthers/Einar/tsg101_screen/TSG101SCREEN_20240810_132824/plate1/datasets/training/model/maxvit_t/rgb/epochs_100/maxvit_t_epoch_100_channels_rgb.pth',
              '/nas_mnt/carruthers/Einar/tsg101_screen/TSG101SCREEN_20240824_072829/plate2/datasets/training/model/maxvit_t/rgb/epochs_100/maxvit_t_epoch_100_channels_rgb.pth',
              '/nas_mnt/carruthers/Einar/tsg101_screen/TSG101SCREEN_20240825_094106/plate3/datasets/training/model/maxvit_t/rgb/epochs_100/maxvit_t_epoch_100_channels_rgb.pth',
              '/nas_mnt/carruthers/Einar/tsg101_screen/TSG101SCREEN_20240826_140251/plate4/datasets/training/model/maxvit_t/rgb/epochs_100/maxvit_t_epoch_100_channels_rgb.pth']

src = '/nas_mnt/carruthers/Einar/tsg101_screen/hits_20250108_181547/plate2/datasets/training'

settings = {
    'teacher_paths':teacher_paths,
    'src':src,
    'student_save_path':'/nas_mnt/carruthers/Einar/tsg101_screen/hits_20250108_181547/plate2/kt_model',
    'device':'cpu',
    'student_model_name':'maxvit_t',
    'pretrained':True,
    'dropout_rate':None,
    'use_checkpoint':False,
    'alpha':0.5,
    'temperature':2.0,
    'lr':1e-4,
    'epochs':10,
    'image_size':224,
    'batch_size':64,
    'classes':('nc','pc'),
    'n_jobs':None,
    'validation_split':0.0,
    'pin_memory':False,
    'normalize':False,
    'channels':('r','g','b'),
    'augment':False,
    'verbose':False}

distilled_model, train_fig = model_knowledge_transfer(settings)



Generating training DataLoader(s) from: /nas_mnt/carruthers/Einar/tsg101_screen/hits_20250108_181547/plate2/datasets/training
Loading Train and validation datasets
Loading teacher models:
  Loading teacher: /nas_mnt/carruthers/Einar/tsg101_screen/TSG101SCREEN_20240810_132824/plate1/datasets/training/model/maxvit_t/rgb/epochs_100/maxvit_t_epoch_100_channels_rgb.pth
  Loading teacher: /nas_mnt/carruthers/Einar/tsg101_screen/TSG101SCREEN_20240824_072829/plate2/datasets/training/model/maxvit_t/rgb/epochs_100/maxvit_t_epoch_100_channels_rgb.pth
  Loading teacher: /nas_mnt/carruthers/Einar/tsg101_screen/TSG101SCREEN_20240825_094106/plate3/datasets/training/model/maxvit_t/rgb/epochs_100/maxvit_t_epoch_100_channels_rgb.pth
  Loading teacher: /nas_mnt/carruthers/Einar/tsg101_screen/TSG101SCREEN_20240826_140251/plate4/datasets/training/model/maxvit_t/rgb/epochs_100/maxvit_t_epoch_100_channels_rgb.pth


ValueError: too many values to unpack (expected 2)