In [1]:
import copy
import os
import random
from dataclasses import dataclass
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torchvision import transforms

# -- Augmentation 1: insert runs of non-functional (padding / NOP-like) bytes into the image --
class CodeInserter:
    """Augment a byte-plot image by splicing in runs of non-functional bytes.

    The image's pixel values are flattened into a 1-D stream. Between one and
    three runs of padding/NOP-like byte values are inserted at random offsets.
    The lengthened stream is re-wrapped into a canvas that roughly keeps the
    original aspect ratio, then resized to a fixed output size.
    """

    def __init__(self, image: Image.Image):
        self.image = image
        # Flattened pixel stream; all insertions operate on this 1-D view.
        self.flat_img = np.array(image).ravel()
        # Byte values for the 'structured_noise' strategy (0x00, 0x90, 0xCC, 0xCD).
        self.opcodes = [0, 144, 204, 205]
        # Ordered byte patterns for the 'pattern' strategy.
        self.patterns = [
            [0],          # Simulates null padding
            [144],        # Simulates single-byte NOP (0x90)
            [204],        # Simulates debug breakpoint (0xCC)
            [102, 144],   # Simulates 2-byte NOP (0x66 0x90)
        ]

    def _generate_random_slice(self, min_len=40, max_len=400):
        """Return a 1-D array of ``min_len``..``max_len`` filler byte values.

        One of two strategies is chosen with equal probability:

        * ``'pattern'``: tile one entry of ``self.patterns`` and trim it to length.
        * ``'structured_noise'``: draw uniformly, with replacement, from a random
          subset (at least 2 values) of ``self.opcodes``.
        """
        slice_length = random.randint(min_len, max_len)
        generation_mode = random.choice(['pattern', 'structured_noise'])
        if generation_mode == 'pattern':
            chosen_pattern = random.choice(self.patterns)
            # Tile enough copies to cover slice_length, then trim to the exact length.
            num_repeats = slice_length // len(chosen_pattern) + 1
            return np.tile(chosen_pattern, num_repeats)[:slice_length]
        num_opcodes = random.randint(2, len(self.opcodes))
        chosen_opcodes = random.sample(self.opcodes, num_opcodes)
        return np.random.choice(chosen_opcodes, size=slice_length)

    def _calculate_new_dimensions(self, new_pixel_count: int):
        """Return ``(width, height)`` such that ``width * height >= new_pixel_count``.

        The height comes from the original aspect ratio (width / height). The
        width is then grown one column at a time until the canvas is big enough.
        Height and width are clamped to at least 1. Without that clamp, a small
        pixel count relative to the aspect ratio yields a zero height, and the
        loop below would never terminate.
        """
        og_w, og_h = self.image.size
        aspect_ratio = og_w / og_h
        new_h = max(1, int(np.sqrt(new_pixel_count / aspect_ratio)))
        new_w = max(1, int(new_h * aspect_ratio))
        while new_w * new_h < new_pixel_count:
            new_w += 1
        return new_w, new_h

    def augment(self, output_size=(64, 64)):
        """Insert 1-3 filler runs and return the result as a new PIL image.

        Parameters
        ----------
        output_size : tuple of int, default (64, 64)
            ``(width, height)`` passed to ``Image.resize``.

        Returns
        -------
        PIL.Image.Image
            A 2-D uint8 image resized with LANCZOS to ``output_size``.

        Raises
        ------
        ValueError
            If the source image has no pixels.
        """
        augmented = self.flat_img.copy()
        max_idx = len(augmented)
        if max_idx == 0:
            raise ValueError("CodeInserter needs a non-empty image")
        # replace=False cannot draw more distinct cut points than there are positions.
        num_insertions = min(random.randint(1, 3), max_idx)
        cut_points = np.sort(np.random.choice(max_idx, size=num_insertions, replace=False))
        # Insert from the highest offset downwards so the lower offsets stay valid.
        for point in cut_points[::-1]:
            augmented = np.insert(augmented, point, self._generate_random_slice())
        new_w, new_h = self._calculate_new_dimensions(len(augmented))
        # Zero-pad the tail so the stream exactly fills the rectangular canvas.
        pad_len = new_w * new_h - len(augmented)
        padded = np.pad(augmented, (0, pad_len), 'constant')
        canvas = padded.reshape((new_h, new_w)).astype(np.uint8)
        return Image.fromarray(canvas).resize(output_size, Image.Resampling.LANCZOS)

# -- Augmentation 2: copy a window of existing (functional) bytes over another part of the image --
class CodeDuplicator:
    """Augment a byte-plot image by copying one window of bytes over another.

    A contiguous run covering ``window_fraction`` of the flattened pixel stream
    is copied onto a second position that does not overlap it. The output keeps
    the input's shape because the copy overwrites existing bytes rather than
    inserting new ones.
    """

    def __init__(self, image: Image.Image, window_fraction: float = 0.1):
        if not 0.0 <= window_fraction <= 1.0:
            raise ValueError(f"window_fraction must be in [0, 1], got {window_fraction}")
        self.image = image
        pixels = np.array(image)
        # Remember the full array shape so multi-channel images round-trip too.
        self.img_shape = pixels.shape
        self.flat_img = pixels.ravel()
        self.img_height, self.img_width = image.height, image.width
        self.window_fraction = window_fraction

    def _get_duplication_parameters(self):
        """Return ``(start_point, insertion_point, window)`` for one copy.

        ``window`` is ``int(window_fraction * n_pixels)``. Both offsets lie in
        ``[0, n_pixels - window]`` and are at least ``window`` apart, so the
        source and destination ranges never overlap.

        Raises
        ------
        ValueError
            If no non-overlapping destination exists. This can only happen when
            the window covers more than about a third of the stream.
        """
        n_pixels = len(self.flat_img)
        window = int(self.window_fraction * n_pixels)
        max_start = n_pixels - window
        start_point = random.randint(0, max_start)
        # A valid destination lies entirely before or entirely after the source.
        # Drawing from these ranges directly avoids an unbounded rejection loop.
        before = range(0, start_point - window + 1)
        after = range(start_point + window, max_start + 1)
        n_choices = len(before) + len(after)
        if n_choices == 0:
            raise ValueError(
                f"cannot place a non-overlapping copy of {window} pixels in a "
                f"{n_pixels}-pixel image; lower window_fraction"
            )
        k = random.randrange(n_choices)
        insertion_point = before[k] if k < len(before) else after[k - len(before)]
        return start_point, insertion_point, window

    def augment(self):
        """Apply the duplication augmentation and return a new PIL image.

        This is the main public method to call. The returned image has the
        same array shape as the input, cast to uint8.
        """
        start, duplicate_at, window = self._get_duplication_parameters()
        augmented_flat = self.flat_img.copy()
        augmented_flat[duplicate_at:duplicate_at + window] = self.flat_img[start:start + window]
        augmented = augmented_flat.reshape(self.img_shape)
        return Image.fromarray(augmented.astype(np.uint8))

# -- Sampling and creation of batches --
class TaskSampler:
    """Sample meta-batches of N-way, K-shot (plus Q query) classification tasks.

    Images are read from ``resized_folder/<family>/``. Each sampled image becomes
    a 3-channel tensor that stacks, in order, the original image, one
    ``CodeInserter`` augmentation and one ``CodeDuplicator`` augmentation.
    """

    def __init__(self, resized_folder: str, families: list, meta_batch_size: int, n_way: int, k_shot: int, q_query: int):
        self.resized_folder = resized_folder
        self.families = families
        self.meta_batch_size = meta_batch_size
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_query = q_query
        # Pre-load all image paths once, so sample() only opens the files it uses.
        self.image_paths_by_family = {
            f: self._list_image_files(os.path.join(resized_folder, f))
            for f in self.families
        }
        self._validate()
        # Converts a PIL image to a (C, H, W) tensor; C is 1 for single-channel images.
        self.to_tensor = transforms.ToTensor()

    @staticmethod
    def _list_image_files(family_dir: str) -> list:
        """Return sorted file paths in ``family_dir``, skipping hidden entries and subdirectories.

        Hidden files (e.g. macOS ``.DS_Store``) are not images and would make
        ``Image.open`` fail. Sorting makes the path order, and therefore seeded
        sampling, independent of the filesystem's listing order.
        """
        return [
            os.path.join(family_dir, name)
            for name in sorted(os.listdir(family_dir))
            if not name.startswith('.') and os.path.isfile(os.path.join(family_dir, name))
        ]

    def _validate(self):
        """Fail fast with a readable message instead of a bare ``random.sample`` error."""
        if self.n_way > len(self.families):
            raise ValueError(
                f"n_way={self.n_way} exceeds the {len(self.families)} available families"
            )
        needed = self.k_shot + self.q_query
        too_small = {f: len(p) for f, p in self.image_paths_by_family.items() if len(p) < needed}
        if too_small:
            raise ValueError(
                f"families with fewer than k_shot + q_query = {needed} images: {too_small}"
            )

    def _create_input_tensor(self, base_img: Image.Image) -> torch.Tensor:
        """Create the 3-channel tensor ``[original, inserted, duplicated]`` from one PIL image.

        NOTE(review): torch.cat requires all three tensors to share H and W.
        ``CodeInserter`` outputs 64x64 by default and ``CodeDuplicator`` keeps the
        input size, so this assumes the images in ``resized_folder`` are 64x64
        single-channel. TODO confirm.
        """
        aug1_img = CodeInserter(base_img).augment()
        aug2_img = CodeDuplicator(base_img).augment()
        base_tensor = self.to_tensor(base_img)
        aug1_tensor = self.to_tensor(aug1_img)
        aug2_tensor = self.to_tensor(aug2_img)
        # Each tensor is (1, H, W) for single-channel input, so concatenate on the channel axis.
        return torch.cat([base_tensor, aug1_tensor, aug2_tensor], dim=0)

    def sample(self):
        """Sample a full meta-batch of tasks.

        Returns
        -------
        tuple of torch.Tensor
            ``(support_x, support_y, query_x, query_y)``:

            * ``support_x``: ``(meta_batch_size, n_way * k_shot, C, H, W)``
            * ``support_y``: ``(meta_batch_size, n_way * k_shot)``
            * ``query_x``: ``(meta_batch_size, n_way * q_query, C, H, W)``
            * ``query_y``: ``(meta_batch_size, n_way * q_query)``

            Labels are ``0..n_way-1``. Label ``i`` is the position of the family
            in that task's random draw.
        """
        support_x_batch, support_y_batch = [], []
        query_x_batch, query_y_batch = [], []

        for _ in range(self.meta_batch_size):
            support_x, support_y = [], []
            query_x, query_y = [], []
            # Names of the families present in this task.
            task_families = random.sample(self.families, self.n_way)

            for label, family in enumerate(task_families):
                sampled_paths = random.sample(
                    self.image_paths_by_family[family], self.k_shot + self.q_query
                )
                class_tensors = []
                for path in sampled_paths:
                    # Close each file once its pixels have been turned into tensors.
                    with Image.open(path) as img:
                        class_tensors.append(self._create_input_tensor(img))
                class_tensors = torch.stack(class_tensors)  # (k_shot + q_query, C, H, W)

                # The first k_shot images form the support set; the rest form the query set.
                support_x.append(class_tensors[:self.k_shot])
                query_x.append(class_tensors[self.k_shot:])
                support_y.append(torch.full((self.k_shot,), label, dtype=torch.long))
                query_y.append(torch.full((self.q_query,), label, dtype=torch.long))

            # Merge all classes of this task.
            support_x_batch.append(torch.cat(support_x, dim=0))
            support_y_batch.append(torch.cat(support_y, dim=0))
            query_x_batch.append(torch.cat(query_x, dim=0))
            query_y_batch.append(torch.cat(query_y, dim=0))

        # Stack all tasks into the final meta-batch.
        return (torch.stack(support_x_batch), torch.stack(support_y_batch),
                torch.stack(query_x_batch), torch.stack(query_y_batch))

# -- Hyper-parameters for the model --
@dataclass
class Hyperparameters:
    """Hyper-parameters for a meta-learning run (inner task loop plus outer meta loop).

    ``inner_lr`` / ``meta_lr`` correspond to alpha / beta in the original paper.
    The field ``traning_steps`` keeps its historical (misspelt) name so existing
    constructor calls keep working; read it through the ``training_steps`` alias.
    """
    n_way: int  # total number of classes to classify in each task
    k_shot: int  # support examples per class in each task
    q_query: int  # query examples per class in each task
    inner_lr: float  # inner-loop (task-specific) learning rate (alpha in the original paper)
    meta_lr: float  # meta-learning rate (beta in the original paper)
    meta_batch_size: int  # number of tasks per meta-batch
    traning_steps: int  # number of meta-updates to perform (misspelt; see `training_steps`)
    inner_steps: int  # number of inner gradient steps ('n' in the original paper)
    meta_training_families: List[str]  # families (tasks) sampled during meta-training
    meta_testing_families: List[str]  # held-out families used for meta-testing
    meta_val_families: List[str]  # families used for validation

    @property
    def training_steps(self) -> int:
        """Correctly spelled, read-only alias for ``traning_steps``."""
        return self.traning_steps

# -- Learning-rate schedules for the inner (alpha) and meta (beta) updates --
class LearningRate:
    """Schedule helpers for the inner-loop (alpha) and meta (beta) learning rates."""

    def __init__(self, initial_inner_lr: float):
        # alpha_1: the initial inner learning rate; it sets the schedule's step sizes.
        self.alpha1 = initial_inner_lr

    def _update_alpha(self, inner_lr: float, inner_steps: int, task_num: int) -> float:
        """Return the next inner learning rate for the 1-based ``task_num``.

        With ``n = inner_steps``, ``delta1 = alpha1 / 4`` and
        ``delta2 = alpha1 / (n + 1)``, the schedule covers tasks 1..4n:

        * ``[1, n+1)``: ``inner_lr + delta1``
        * ``[n+1, 2n+1)``: ``inner_lr - delta1``
        * ``[2n+1, 3n+1)``: ``inner_lr + delta2``
        * ``[3n+1, 4n+1)``: ``inner_lr - delta2``

        Raises
        ------
        ValueError
            If ``task_num`` is outside ``[1, 4n]``. Returning ``None`` there would
            silently corrupt later arithmetic.
            NOTE(review): the four phases cancel out, so wrapping ``task_num``
            modulo 4n may be the intended behavior. TODO confirm.
        """
        alpha_i = inner_lr
        n = inner_steps
        delta1 = self.alpha1 / 4
        delta2 = self.alpha1 / (n + 1)
        if 1 <= task_num < n + 1:
            return alpha_i + delta1
        if n + 1 <= task_num < 2 * n + 1:
            return alpha_i - delta1
        if 2 * n + 1 <= task_num < 3 * n + 1:
            return alpha_i + delta2
        if 3 * n + 1 <= task_num < 4 * n + 1:
            return alpha_i - delta2
        raise ValueError(
            f"task_num={task_num} is outside the schedule range [1, {4 * n}] "
            f"for inner_steps={n}"
        )

    # Original (misspelt) name, kept so existing callers keep working.
    _upate_alpha = _update_alpha

    def _update_beta(self, meta_lr: float, inner_steps: int) -> float:
        """Return the meta learning rate scaled down by the number of inner steps."""
        return meta_lr / inner_steps

# -- Base Learner Model --
class CNN(nn.Module):
    """Four-block convolutional base learner.

    Each block is Conv3x3 (padding 1) -> BatchNorm -> ELU -> MaxPool(2), so every
    block halves the spatial size. A square ``input_size`` x ``input_size`` input
    leaves a 64-channel ``(input_size // 16)``-sided feature map for the linear
    classifier.

    Parameters
    ----------
    in_channel : int
        Number of input channels.
    n_way : int
        Number of output classes (logits).
    input_size : int, default 64
        Height and width of the (square) input images; must be at least 16.
    dropout : float, default 0.25
        Dropout probability applied after blocks 2 and 4.
    """

    def __init__(self, in_channel: int, n_way: int, input_size: int = 64, dropout: float = 0.25):
        super().__init__()
        feat_side = input_size // 16  # four 2x max-pools
        if feat_side < 1:
            raise ValueError(f"input_size must be >= 16, got {input_size}")
        self.layer1 = self._make_conv_block(in_channel, 64)
        self.layer2 = self._make_conv_block(64, 64)
        self.layer3 = self._make_conv_block(64, 64)
        self.layer4 = self._make_conv_block(64, 64)
        # One Dropout module, reused after each pair of blocks.
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(64 * feat_side * feat_side, n_way)

    def _make_conv_block(self, in_channels: int, out_channels: int):
        """Build one Conv3x3 -> BatchNorm -> ELU -> MaxPool(2) block.

        Parameters
        ----------
        in_channels : int
            Input channels for the conv layer.
        out_channels : int
            Output channels of the layer (number of filters applied).
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ELU(),
            nn.MaxPool2d(2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a ``(batch, in_channel, input_size, input_size)`` tensor to ``(batch, n_way)`` logits."""
        # First group
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.dropout(x)
        # Second group
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.dropout(x)
        # Flatten everything except the batch dimension for the linear layer.
        x = x.flatten(1)
        return self.classifier(x)


In [None]:
# Pick the best available accelerator: CUDA, then Apple-silicon MPS, then CPU.
# Never select a backend without checking that it is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f'Using: {device}')


Using: mps


In [None]:
# Set hyper parameters
hp = Hy