In [2]:
import os
from PIL import Image
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from dataclasses import dataclass
from typing import List
import random
from torchvision import transforms
import numpy as np

# -- To insert non functional code in the image to get augmented image --
class CodeInserter:
    """Augments a single-channel image by inserting runs of non-functional "code" bytes.

    The image is flattened to a 1-D pixel/byte stream, 1-3 runs of filler bytes are inserted
    at random offsets, the longer stream is folded back into a 2-D image with roughly the
    original aspect ratio (zero-padded at the end), and the result is resized.
    """

    def __init__(self, image: Image.Image):
        self.image = image
        # Flattened pixel values. The reshape in `augment` produces a 2-D image, so this
        # assumes a single-channel (e.g. mode "L") input — TODO confirm for other modes.
        self.flat_img = np.array(image).ravel()
        # Byte values used for the random-mix strategy: 0x00, 0x90, 0xCC, 0xCD.
        self.opcodes = [0, 144, 204, 205]
        self.patterns = [
            [0],              # Simulates null padding
            [144],            # Simulates single-byte NOP
            [204],            # Simulates debug breakpoint
            [102, 144]        # Simulates 2-byte NOP (0x66 0x90)
        ]
        # A [255] "filler bytes" pattern is currently disabled.

    def _generate_random_slice(self, min_len=40, max_len=400):
        """Return a 1-D array of filler bytes with a random length in [min_len, max_len].

        Half of the time a fixed pattern from `self.patterns` is tiled; otherwise a random
        mix of 2+ distinct values from `self.opcodes` is drawn.
        """
        slice_length = random.randint(min_len, max_len)
        generation_mode = random.choice(['pattern', 'structured_noise'])
        if generation_mode == 'pattern':
            # Strategy 1: repeat one ordered pattern, then trim to the exact length.
            chosen_pattern = random.choice(self.patterns)
            num_repeats = (slice_length // len(chosen_pattern)) + 1
            tiled_pattern = np.tile(chosen_pattern, num_repeats)
            return tiled_pattern[:slice_length]
        # Strategy 2: random mix of a random subset of the opcodes.
        num_opcodes = random.randint(2, len(self.opcodes))
        chosen_opcodes = random.sample(self.opcodes, num_opcodes)
        return np.random.choice(chosen_opcodes, size=slice_length)

    def _calculate_new_dimensions(self, new_pixel_count: int):
        """Return (width, height) of a canvas holding `new_pixel_count` pixels.

        The height follows the original aspect ratio; the width is the smallest value that is
        at least the aspect-ratio width and makes width * height >= new_pixel_count.
        """
        og_w, og_h = self.image.size
        aspect_ratio = og_w / og_h
        # At least one row: with zero rows no width could ever hold the pixels, and the old
        # width-increment loop never terminated.
        new_h = max(1, int(np.sqrt(new_pixel_count / aspect_ratio)))
        # Closed form of "increment the width until the canvas is big enough" (ceil division).
        new_w = max(int(new_h * aspect_ratio), -(-new_pixel_count // new_h))
        return new_w, new_h

    def augment(self, size=(64, 64)):
        """Apply the insertion augmentation and return a new PIL image.

        Parameters
        ----------
        size : tuple of int, default (64, 64)
            (width, height) of the returned image. It must equal the base image size when the
            result is channel-stacked with the original (as TaskSampler does).
        """
        augmented_image = self.flat_img.copy()
        max_idx = len(augmented_image)
        # Never request more distinct cut points than there are pixels.
        num_insertions = min(random.randint(1, 3), max_idx)
        cut_points = np.sort(np.random.choice(max_idx, size=num_insertions, replace=False))
        # Insert from the highest offset down so the lower offsets stay valid.
        for point in cut_points[::-1]:
            insertion_slice = self._generate_random_slice()
            augmented_image = np.insert(augmented_image, point, insertion_slice)
        new_w, new_h = self._calculate_new_dimensions(len(augmented_image))
        # Zero-pad the stream so it exactly fills the new rectangular canvas.
        pad_len = (new_w * new_h) - len(augmented_image)
        padded_image = np.pad(augmented_image, (0, pad_len), 'constant')
        reshaped_image = padded_image.reshape((new_h, new_w)).astype(np.uint8)
        final_image = Image.fromarray(reshaped_image)
        return final_image.resize(size, Image.Resampling.LANCZOS)

# -- To duplicate functional code in the image to get augmented image --
class CodeDuplicator:
    """Augments a single-channel image by copying one pixel window over another location.

    The image is flattened, a contiguous window covering `window_ratio` of the pixels is
    copied and written over a non-overlapping window elsewhere in the stream (the stream
    length is unchanged), and the result is reshaped back to the original height x width.
    """

    def __init__(self, image: Image.Image, window_ratio: float = 0.1):
        if not 0.0 <= window_ratio <= 1.0:
            raise ValueError(f"window_ratio must be in [0, 1], got {window_ratio}")
        self.image = image
        self.flat_img = np.array(image).ravel()
        self.img_height, self.img_width = image.height, image.width
        self.window_ratio = window_ratio

    def _get_duplication_parameters(self):
        """Return (source_start, target_start, window) for a non-overlapping copy.

        Both starts lie in [0, len - window] and are at least `window` apart. The target is
        drawn uniformly from all valid positions — the same distribution the former
        rejection-sampling loop produced, but with bounded work.

        Raises
        ------
        ValueError
            If no valid target exists for the drawn source (only possible for large ratios).
        """
        window = int(self.window_ratio * self.flat_img.size)
        max_start = len(self.flat_img) - window
        start_point = random.randint(0, max_start)
        # Valid targets are [0, start - window] and [start + window, max_start]. The second
        # range begins past the first so that window == 0 does not count `start` twice.
        left = range(0, start_point - window + 1)
        right = range(max(start_point + window, start_point - window + 1), max_start + 1)
        n_valid = len(left) + len(right)
        if n_valid == 0:
            raise ValueError(
                f"no non-overlapping target for window={window} in an image of {len(self.flat_img)} pixels"
            )
        k = random.randrange(n_valid)
        insertion_point = left[k] if k < len(left) else right[k - len(left)]
        return start_point, insertion_point, window

    def augment(self):
        """Apply the duplication augmentation and return a new PIL image of the same size."""
        start, duplicate_at, window = self._get_duplication_parameters()
        augmented_flat = self.flat_img.copy()
        # Overwrite (not insert) the target window with a copy of the source window.
        augmented_flat[duplicate_at:duplicate_at + window] = self.flat_img[start:start + window]
        augmented_2d = augmented_flat.reshape((self.img_height, self.img_width))
        return Image.fromarray(augmented_2d.astype(np.uint8))

# -- Sampling and creation of batches --
class TaskSampler:
    """Samples meta-batches of N-way K-shot tasks from per-family image directories.

    `resized_folder/<family>/` holds the images of one family; every non-hidden regular file
    there is treated as an image. Each sampled image becomes a 3-channel tensor stacking the
    original image with a CodeInserter and a CodeDuplicator augmentation of it.
    """

    def __init__(self, resized_folder: str, families: list, meta_batch_size: int, n_way: int, k_shot: int, q_query: int):
        self.resized_folder = resized_folder
        self.families = families
        self.meta_batch_size = meta_batch_size
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_query = q_query
        if n_way > len(families):
            raise ValueError(f"n_way={n_way} exceeds the {len(families)} available families")
        # Pre-load all image paths once; sampling then only opens the files it draws.
        self.image_paths_by_family = {
            f: self._list_images(os.path.join(resized_folder, f)) for f in self.families
        }
        # Fail early with a readable message instead of inside random.sample() later.
        needed = k_shot + q_query
        too_small = {f: len(p) for f, p in self.image_paths_by_family.items() if len(p) < needed}
        if too_small:
            raise ValueError(
                f"each family needs at least k_shot + q_query = {needed} images; too few in: {too_small}"
            )
        # Converts a PIL image to a tensor; a single-channel image gives shape (1, H, W).
        self.to_tensor = transforms.ToTensor()

    @staticmethod
    def _list_images(family_dir: str) -> list:
        """Return the sorted paths of the non-hidden regular files in `family_dir`."""
        # Hidden files such as macOS's .DS_Store are not images and would crash Image.open.
        return sorted(
            os.path.join(family_dir, name)
            for name in os.listdir(family_dir)
            if not name.startswith('.') and os.path.isfile(os.path.join(family_dir, name))
        )

    @staticmethod
    def _load_image(path: str) -> Image.Image:
        """Load `path` as a single-channel ("L") image and release the file handle."""
        # The augmentations reshape the flattened pixels into 2-D, so one channel is required.
        with Image.open(path) as img:
            return img.convert('L')

    def _create_input_tensor(self, base_img: Image.Image) -> torch.Tensor:
        """Create the 3-channel tensor [original, inserted, duplicated] from one PIL image."""
        aug1_img = CodeInserter(base_img).augment()
        aug2_img = CodeDuplicator(base_img).augment()
        base_tensor = self.to_tensor(base_img)
        aug1_tensor = self.to_tensor(aug1_img)
        aug2_tensor = self.to_tensor(aug2_img)
        # Each to_tensor output is (1, H, W), so concatenating on dim 0 gives (3, H, W).
        return torch.cat([base_tensor, aug1_tensor, aug2_tensor], dim=0)

    def sample(self):
        """Sample a full meta-batch of tasks.

        Returns
        -------
        tuple of torch.Tensor
            (support_x, support_y, query_x, query_y). Dimension 0 indexes the task. Within
            a task, examples are grouped by class, and labels are 0..n_way-1 in the order
            the families were drawn.
        """
        support_x_batch, support_y_batch = [], []
        query_x_batch, query_y_batch = [], []

        for _ in range(self.meta_batch_size):
            support_x, support_y = [], []
            query_x, query_y = [], []
            # Names of the families that make up this task.
            task_families = random.sample(self.families, self.n_way)

            for label, family in enumerate(task_families):
                sampled_paths = random.sample(self.image_paths_by_family[family], self.k_shot + self.q_query)
                # Load, augment and stack: (k_shot + q_query, 3, H, W).
                class_tensors = torch.stack(
                    [self._create_input_tensor(self._load_image(p)) for p in sampled_paths]
                )
                support_x.append(class_tensors[:self.k_shot])
                query_x.append(class_tensors[self.k_shot:])
                support_y.append(torch.full((self.k_shot,), label, dtype=torch.long))
                query_y.append(torch.full((self.q_query,), label, dtype=torch.long))

            # Aggregate all classes of this task.
            support_x_batch.append(torch.cat(support_x, dim=0))
            support_y_batch.append(torch.cat(support_y, dim=0))
            query_x_batch.append(torch.cat(query_x, dim=0))
            query_y_batch.append(torch.cat(query_y, dim=0))

        # Stack all tasks into the final meta-batch.
        return (torch.stack(support_x_batch), torch.stack(support_y_batch),
                torch.stack(query_x_batch), torch.stack(query_y_batch))

# -- Hyper-parameters for the model --
@dataclass
class Hyperparameters:
    """Configuration for Mi-MAML meta-training.

    Field order is part of the constructor interface and is kept as-is. The field
    `traning_steps` keeps its historical spelling for compatibility; use the read-only
    `training_steps` alias in new code.
    """
    n_way: int  # number of classes per task
    k_shot: int  # support examples per class in each task
    q_query: int  # query examples per class in each task
    inner_lr: float  # inner-loop (task-specific) learning rate (alpha in the original paper)
    meta_lr: float  # meta-learning rate (beta in the original paper)
    meta_batch_size: int  # number of tasks per meta-batch
    traning_steps: int  # number of meta-updates to perform (sic; see `training_steps`)
    inner_steps: int  # gradient steps per task in the inner loop ('n' in the original paper)
    meta_training_families: List[str]  # families from which training tasks are sampled
    meta_testing_families: List[str]  # held-out families used for final testing
    meta_val_families: List[str]  # families used for validation

    @property
    def training_steps(self) -> int:
        """Correctly spelled, read-only alias for `traning_steps`."""
        return self.traning_steps

# -- Handle the updates in Learning Rate --
class LearningRate:
    """Stateless helpers for the Mi-MAML learning-rate schedules.

    `_upate_alpha` (alias `_update_alpha`) implements the cyclic inner-loop (alpha) step;
    `_update_beta` implements the meta (beta) learning-rate reduction.
    """

    def __init__(self, initial_inner_lr: float):
        self.alpha1 = initial_inner_lr

    def _upate_alpha(self, inner_lr: float, inner_steps: int, task_num: int) -> float:
        """Return the next inner learning rate for 1-indexed task `task_num`.

        A cycle spans 4 * n tasks (n = inner_steps) in four equal phases, which add alpha1/4,
        subtract alpha1/4, add alpha1/(n+1) and subtract alpha1/(n+1) to `inner_lr`. Task
        numbers outside [1, 4n] wrap around the cycle; previously they returned None.

        Raises
        ------
        ValueError
            If inner_steps < 1.
        """
        if inner_steps < 1:
            raise ValueError(f"inner_steps must be >= 1, got {inner_steps}")
        n = inner_steps
        delta1 = self.alpha1 / 4
        delta2 = self.alpha1 / (n + 1)
        # 1-indexed phases [1, n], [n+1, 2n], [2n+1, 3n], [3n+1, 4n] -> 0, 1, 2, 3.
        phase = ((task_num - 1) % (4 * n)) // n
        if phase == 0:
            return inner_lr + delta1
        if phase == 1:
            return inner_lr - delta1
        if phase == 2:
            return inner_lr + delta2
        return inner_lr - delta2

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    _update_alpha = _upate_alpha

    def _update_beta(self, meta_lr: float, inner_steps: int) -> float:
        """Return the reduced meta learning rate meta_lr / inner_steps."""
        return meta_lr / inner_steps

# -- Handles the accurate update of both learning rates --
class LearningRateManager:
    """Drives both Mi-MAML learning-rate schedules.

    * DILLR (inner loop): `_get_inner_lr` maps a task index to a cyclic inner learning rate.
    * AOLLR (outer loop): `_check_plateau` multiplies every meta-optimizer param-group lr by
      1 / inner_steps after `inner_steps` consecutive non-improving validation checks.
    """

    def __init__(self, initial_lr, meta_optimizer, inner_steps):
        if inner_steps < 1:
            raise ValueError(f"inner_steps must be >= 1, got {inner_steps}")
        self.inner_lr = initial_lr
        self.initial_lr = initial_lr
        self.meta_optimizer = meta_optimizer
        self.inner_steps = inner_steps
        self.val_loss_plateau_counter = 0

    def _get_inner_lr(self, task_idx):
        """Return the DILLR inner learning rate for 0-based `task_idx` (also stored in `self.inner_lr`).

        Over a cycle of 4n indices (n = inner_steps), the rate rises by initial_lr/4 per task
        for n tasks and falls back over the next n. It then rises and falls by
        initial_lr/(n+1) per task, ending the cycle at initial_lr.

        The value is a pure function of `task_idx`. Up to float rounding, it equals what the
        former accumulate-on-every-call version produced when called once per consecutive
        index from 0. Querying the same index repeatedly (e.g. once per inner step) no longer
        compounds the deltas, and long runs no longer drift.
        """
        n = self.inner_steps
        i = task_idx % (4 * n)
        delta1 = self.initial_lr / 4
        delta2 = self.initial_lr / (n + 1)
        if i < n:
            offset = (i + 1) * delta1
        elif i < 2 * n:
            offset = (2 * n - 1 - i) * delta1
        elif i < 3 * n:
            offset = (i - 2 * n + 1) * delta2
        else:
            offset = (4 * n - 1 - i) * delta2
        self.inner_lr = self.initial_lr + offset
        return self.inner_lr

    def _check_plateau(self, current_val_loss, best_val_loss):
        """AOLLR: count non-improving checks and shrink the meta lr when the patience runs out.

        The caller is responsible for tracking and updating `best_val_loss`.
        """
        if current_val_loss >= best_val_loss:
            self.val_loss_plateau_counter += 1
        else:
            self.val_loss_plateau_counter = 0  # better validation performance resets patience
        if self.val_loss_plateau_counter >= self.inner_steps:
            print(f'Validation loss has plateaued. Reducing meta_lr')
            for g in self.meta_optimizer.param_groups:
                g['lr'] *= (1 / self.inner_steps)
            self.val_loss_plateau_counter = 0  # reset after reducing

# -- Base Learner Model --
class CNN(nn.Module):
    """Four-block convolutional base learner.

    Each block is Conv3x3 -> BatchNorm -> ELU -> MaxPool(2), so the spatial size is divided
    by 16 overall. Dropout follows the second and the fourth block, and a linear layer
    produces the logits.

    Parameters
    ----------
    in_channel : int
        Number of input channels.
    n_way : int
        Number of output classes.
    hidden_channels : int, default 64
        Filters in every convolutional block.
    input_size : int, default 64
        Side length of the square input images; it determines the classifier's input width.
    dropout_p : float, default 0.25
        Dropout probability.
    """

    def __init__(self, in_channel: int, n_way: int, hidden_channels: int = 64,
                 input_size: int = 64, dropout_p: float = 0.25):
        super().__init__()
        self.layer1 = self._make_conv_block(in_channel, hidden_channels)
        self.layer2 = self._make_conv_block(hidden_channels, hidden_channels)
        self.layer3 = self._make_conv_block(hidden_channels, hidden_channels)
        self.layer4 = self._make_conv_block(hidden_channels, hidden_channels)
        # Must be a module, not a 1-tuple: a tuple is neither callable nor registered.
        self.dropout = nn.Dropout(dropout_p)
        # Four 2x max-pools floor-divide the side length by 16.
        feature_side = input_size // 16
        if feature_side < 1:
            raise ValueError(f"input_size must be >= 16, got {input_size}")
        self.classifier = nn.Linear(hidden_channels * feature_side * feature_side, n_way)

    def _make_conv_block(self, in_channels: int, out_channels: int):
        """Build one Conv3x3 -> BatchNorm -> ELU -> MaxPool(2) block.

        Parameters
        ----------
        in_channels : int
            The input channels for the conv layer.
        out_channels : int
            The output channels of the layer (number of filters applied).
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ELU(),
            nn.MaxPool2d(2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (batch, n_way) for images of shape (batch, C, H, W)."""
        # First group
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.dropout(x)
        # Second group
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.dropout(x)
        # Flatten everything except the batch dimension for the linear layer.
        x = torch.flatten(x, 1)
        return self.classifier(x)


In [4]:
# Prefer CUDA, then Apple-silicon MPS, and fall back to CPU so the notebook runs anywhere.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f'Using: {device}')


Using: mps


In [5]:
# Seed every RNG used by task sampling, augmentation and weight init so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set hyper parameters
hp = Hyperparameters(
    n_way=5,
    k_shot=2,
    q_query=15,
    inner_lr=0.01,
    meta_lr=0.001,
    meta_batch_size=16,
    traning_steps=100*100,
    inner_steps=5,
    meta_training_families=['Adialer.C', 'Agent.FYI', 'Allaple.A', 'Alueron.gen!J', 'Autorun.K', 'C2LOP.gen!g', 'Lolyda.AA2', 'Lolyda.AT',
                        'Rbot!gen', 'Skintrim.N', 'Swizzor.gen!E', 'Swizzor.gen!I', 'VB.AT', 'Wintrim.BX', 'Yuner.A'],
    meta_val_families=['Dontovo.A', 'Fakerean', 'Instantaccess', 'Lolyda.AA3', 'Obfuscator.AD'],
    meta_testing_families=['Dialplatform.B', 'Allaple.L', 'C2LOP.P', 'Lolyda.AA1', 'Malex.gen!J']
)
# Val/test tasks must only contain families never seen during meta-training.
_train_families = set(hp.meta_training_families)
_val_families = set(hp.meta_val_families)
_test_families = set(hp.meta_testing_families)
assert (_train_families.isdisjoint(_val_families)
        and _train_families.isdisjoint(_test_families)
        and _val_families.isdisjoint(_test_families)), "meta-train/val/test family splits must be disjoint"

# Set Task Sampler
sampler = TaskSampler(
    resized_folder='malimg_resized',
    families=hp.meta_training_families,
    meta_batch_size=hp.meta_batch_size,
    n_way=hp.n_way,
    k_shot=hp.k_shot,
    q_query=hp.q_query
)
# Initialize the model and move it to the device
meta_model = CNN(in_channel=3, n_way=hp.n_way).to(device=device)
meta_optimizer = optim.Adam(meta_model.parameters(), lr=hp.meta_lr)
# Learning-rate schedules (DILLR for the inner loop, AOLLR for the meta optimizer)
lr_manager = LearningRateManager(initial_lr=hp.inner_lr, meta_optimizer=meta_optimizer, inner_steps=hp.inner_steps)


In [None]:
# Function to evaluate the model
def evaluate(model, families, n_way, k_shot, q_query, num_eval_tasks=100):
    model.eval()
    total_correct = 0
    total_samples = 0
    # temp sampler for evaluation
    eval_sampler = TaskSampler(resized_folder='malimg_resized', families=families,
                               meta_batch_size=num_eval_tasks, n_way=n_way, k_shot=k_shot, q_query=q_query)
    support_x, support_y, query_x, query_y = eval_sampler.sample()
    support_x, support_y = support_x.to(device=device), support_y.to(device=device)
    query_x, query_y = query_x.to(device=device), query_y.to(device=device)

    for i in range(num_eval_tasks):
        with torch.no_grad():
            fast_model = copy.deepcopy(model)
            sx, sy = support_x[i], support_y[i]
            qx, qy = query_x[i], query_y[i]

            # Adapt model on the support (no optimizer needed)
            for _ in range(hp.inner_steps):
                logits = fast_model(sx)
                loss = F.cross_entropy(logits, sy)
                grads = torch.autograd.grad(loss, fast_model.parameters())
                # Update the inner loop
                for p, g in zip(fast_model.parameters(), grads):
                    p.data -= hp.inner_lr*g
                # Evaluate on query
                query_logits = fast_model(qx)
                preds = torch.argmax(query_logits, dim=1)
                total_correct+= (preds==qy).sum().item()
                total_samples+= len(qy)
        return total_correct / total_samples


In [None]:
import time
from tqdm import tqdm


def _fomaml_task_update(model, support_x, support_y, query_x, query_y, inner_steps, inner_lr, grad_scale):
    """Adapt a copy of `model` to one task and add its first-order meta-gradient to `model`.

    The copy takes `inner_steps` plain SGD steps on the support set. The adapted copy's query
    loss is then differentiated w.r.t. the copy's parameters. Following first-order MAML,
    that gradient times `grad_scale` is accumulated into the `.grad` of the matching
    parameter of `model`. A deepcopy's parameters are independent leaves, so calling
    backward() on the copy's loss would never reach `model`.

    Returns
    -------
    float
        The adapted copy's query loss.
    """
    fast_model = copy.deepcopy(model)
    fast_params = [p for p in fast_model.parameters() if p.requires_grad]
    for _ in range(inner_steps):
        support_loss = F.cross_entropy(fast_model(support_x), support_y)
        grads = torch.autograd.grad(support_loss, fast_params)
        with torch.no_grad():
            for p, g in zip(fast_params, grads):
                p -= inner_lr * g
    query_loss = F.cross_entropy(fast_model(query_x), query_y)
    query_grads = torch.autograd.grad(query_loss, fast_params)
    # deepcopy preserves parameter order and requires_grad, so the two lists line up.
    meta_params = [p for p in model.parameters() if p.requires_grad]
    for meta_p, g in zip(meta_params, query_grads):
        contribution = g * grad_scale
        if meta_p.grad is None:
            meta_p.grad = contribution
        else:
            meta_p.grad.add_(contribution)
    return query_loss.item()


# Logging and checkpoint variables
best_val_acc = 0.0
steps_per_epoch = 100
print(f'Starting the Mi-MAML training on : {device}')
for step in tqdm(range(hp.traning_steps), desc="Mi-MAML training"):
    meta_model.train()
    support_x, support_y, query_x, query_y = sampler.sample()
    support_x, support_y = support_x.to(device=device), support_y.to(device=device)
    query_x, query_y = query_x.to(device=device), query_y.to(device=device)

    meta_optimizer.zero_grad()
    total_query_loss = 0.0
    for i in range(hp.meta_batch_size):
        # DILLR: one inner learning rate per task, held fixed for all of its inner steps.
        current_inner_lr = lr_manager._get_inner_lr(step * hp.meta_batch_size + i)
        total_query_loss += _fomaml_task_update(
            meta_model, support_x[i], support_y[i], query_x[i], query_y[i],
            inner_steps=hp.inner_steps, inner_lr=current_inner_lr,
            grad_scale=1.0 / hp.meta_batch_size,
        )
    # meta_model's .grad now holds the task-averaged first-order meta-gradient.
    meta_optimizer.step()
    average_meta_ls = total_query_loss / hp.meta_batch_size

    # Validation and checkpointing
    if (step + 1) % steps_per_epoch == 0:
        epoch = (step + 1) // steps_per_epoch
        val_acc = evaluate(meta_model, hp.meta_val_families, hp.n_way, hp.k_shot, hp.q_query)
        # tqdm.write keeps log lines from breaking the progress bar.
        tqdm.write(f"Epoch {epoch}/{hp.traning_steps // steps_per_epoch} | Meta Loss: {average_meta_ls:.4f} | Val Acc: {val_acc:.4f}")
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            tqdm.write("  -> New best validation accuracy! Saving model...")
            torch.save(meta_model.state_dict(), "best_maml_model.pth")
        # The paper doesn't mention using validation loss for AOLLR, but it's a common practice.
        # lr_manager._check_plateau(current_val_loss, best_val_loss)


In [None]:
# --- Final testing on the held-out families ---
print("\nTraining finished. Loading best model and evaluating on the test set.")
# map_location lets a checkpoint saved on another device (e.g. CUDA) load onto the current one.
meta_model.load_state_dict(torch.load("best_maml_model.pth", map_location=device))
test_accuracy = evaluate(meta_model, hp.meta_testing_families, hp.n_way, hp.k_shot, hp.q_query)
print(f"\nFinal Test Accuracy: {test_accuracy:.4f}")