In [None]:
# installation
!pip install pandas
!pip install torchinfo
!pip install tqdm
!pip install torch

In [None]:
!curl -L -u $username:$token\
  -o benign-dataset.zip\
  https://www.kaggle.com/api/v1/datasets/download/aquacoder/benign-dataset

In [None]:
!mkdir -p dataset
!unzip -q benign-dataset.zip -d dataset

# Dataset.py


In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

def _keep_only_class(dataset, class_name, root):
    """Restrict an ImageFolder-style dataset, in place, to the samples of one class.

    ``samples``, ``imgs`` and ``targets`` are all updated so they stay consistent
    with each other.

    Args:
        dataset: object exposing ``class_to_idx`` (dict) and ``samples``
            (list of ``(path, label)`` tuples), e.g. ``datasets.ImageFolder``.
        class_name: name of the class folder to keep.
        root: dataset root directory; only used in the error message.

    Returns:
        The same ``dataset`` object after filtering.

    Raises:
        FileNotFoundError: if ``class_name`` is not one of the dataset's classes.
    """
    if class_name not in dataset.class_to_idx:
        raise FileNotFoundError(
            f"Could not find the expected '{class_name}' class folder inside {root}"
        )
    class_idx = dataset.class_to_idx[class_name]
    kept = [(path, label) for path, label in dataset.samples if label == class_idx]
    dataset.samples = kept
    dataset.imgs = kept  # ImageFolder exposes `imgs` as an alias of `samples`
    dataset.targets = [label for _, label in kept]
    return dataset


def get_dataloaders(image_size, batch_size, data_root='./dataset', num_workers=4):
    """Build the train/test DataLoaders for benign-only anomaly detection.

    Expected layout::

        {data_root}/train/benign/...           (other train class folders are ignored)
        {data_root}/test/benign/...
        {data_root}/test/<other classes>/...   (e.g. malware)

    Images are resized to ``image_size`` x ``image_size`` and normalised to
    [-1, 1] to match the generator's Tanh output range.

    Args:
        image_size: target height and width in pixels.
        batch_size: batch size for both loaders.
        data_root: directory containing the ``train`` and ``test`` folders.
        num_workers: DataLoader worker processes.

    Returns:
        ``(train_loader, test_loader, test_classes)`` where ``test_classes`` is
        ``test_dataset.classes`` (index -> class-folder name).

    Raises:
        FileNotFoundError: if the train split has no ``benign`` class folder.
    """
    import os

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # [0, 1] -> [-1, 1] for Tanh
    ])

    train_root = os.path.join(data_root, 'train')
    train_dataset = datasets.ImageFolder(root=train_root, transform=transform)
    # Training uses normal (benign) data only; anomalies appear only at test time.
    _keep_only_class(train_dataset, 'benign', train_root)
    print(f"Training on {len(train_dataset.samples)} samples from the 'benign' class.")

    test_dataset = datasets.ImageFolder(root=os.path.join(data_root, 'test'), transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader, test_dataset.classes

# Losses.py

In [None]:
import torch
import torch.nn as nn
import torch.autograd as autograd

def l1(x, y):
    """Per-sample L1 distance between two batched tensors.

    All non-batch dimensions are flattened and summed.

    Args:
        x, y: tensors of identical shape ``(batch, ...)``.

    Returns:
        Tensor of shape ``(batch,)`` holding ``sum(|x_i - y_i|)`` for each sample.
    """
    # reshape (not view) so non-contiguous inputs such as transposed/permuted
    # tensors are accepted instead of raising a RuntimeError.
    x = x.reshape(x.size(0), -1)
    y = y.reshape(y.size(0), -1)
    return torch.sum(torch.abs(x - y), dim=1)

def gradient_penalty(discriminator, x, x_gen, z, z_gen, device):
    """WGAN-GP gradient penalty for a critic over (image, latent) pairs.

    One mixing coefficient per example interpolates both the image pair
    ``(x, x_gen)`` and the latent pair ``(z, z_gen)``. The penalty is the batch
    mean of ``(||grad||_2 - 1) ** 2``, where ``grad`` concatenates the critic's
    gradients w.r.t. the interpolated image and latent.

    Args:
        discriminator: callable ``discriminator(images, latents)`` returning scores.
        x, x_gen: image batches of identical shape ``(batch, ...)``.
        z, z_gen: latent batches of identical shape ``(batch, ...)``; any rank.
        device: device on which the mixing coefficients are sampled.

    Returns:
        Scalar tensor built with ``create_graph=True``, so it can be
        back-propagated into the discriminator's parameters.
    """
    batch_size = x.size(0)

    alpha = torch.rand(batch_size, 1, device=device)
    # Reshape the same coefficients so they broadcast over every non-batch
    # dimension of the image and latent tensors, whatever their rank.
    alpha_img = alpha.view(batch_size, *([1] * (x.dim() - 1)))
    alpha_z = alpha.view(batch_size, *([1] * (z.dim() - 1)))

    x_hat = (alpha_img * x + (1 - alpha_img) * x_gen).detach().requires_grad_(True)
    z_hat = (alpha_z * z + (1 - alpha_z) * z_gen).detach().requires_grad_(True)
    score_hat = discriminator(x_hat, z_hat)

    # create_graph=True: the penalty must itself be differentiable w.r.t. the
    # critic's parameters.
    dx, dz = autograd.grad(
        outputs=score_hat,
        inputs=[x_hat, z_hat],
        grad_outputs=torch.ones_like(score_hat),
        create_graph=True,
        retain_graph=True,
    )

    grads = torch.cat([dx.reshape(batch_size, -1), dz.reshape(batch_size, -1)], dim=1)
    # The epsilon keeps sqrt differentiable when a gradient is exactly zero.
    grads_norm = torch.sqrt(torch.sum(grads ** 2, dim=1) + 1e-12)
    norm_penalty = (grads_norm - 1) ** 2
    return norm_penalty.mean()


# Model.py

In [None]:
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

import cv2

import numpy as np
import torch
import torch.nn as nn
from torchinfo import summary

def get_bn_layer(bn_type='none', num_features=None, dims=None):
    """Create the normalisation layer selected by ``bn_type``.

    Args:
        bn_type: 'batch', 'layer' (GroupNorm with a single group), 'instance'
            or 'none' (identity).
        num_features: number of channels / features the layer normalises.
        dims: 1 selects the 1-D variant (used after dense layers); any other
            value, including the default None, selects the 2-D variant.
            Ignored for 'layer' and 'none'.

    Returns:
        An ``nn.Module``.

    Raises:
        ValueError: for an unknown ``bn_type``.
    """
    one_dimensional = dims == 1

    if bn_type == 'none':
        return nn.Identity()

    if bn_type == 'batch':
        norm_cls = nn.BatchNorm1d if one_dimensional else nn.BatchNorm2d
        return norm_cls(num_features)

    if bn_type == 'layer':
        # A single group normalises over all channels of each sample.
        return nn.GroupNorm(1, num_features)

    if bn_type == 'instance':
        if one_dimensional:
            return nn.InstanceNorm1d(num_features)
        return nn.InstanceNorm2d(num_features, affine=True)

    raise ValueError(f"Unsupported normalization layer type: {bn_type}")

class GBlock(nn.Module):
    """Residual generator block with optional 2x upsampling of its input.

    Main path: conv3x3 -> norm -> act -> conv3x3 -> norm -> act -> conv1x1.
    Skip path: conv1x1 -> norm (projects ``in_channels`` to ``out_channels``).
    Output:    act(norm(main + skip)).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        upsample: if True, H and W are doubled before both paths.
        upsample_type: 'bilinear' uses ``nn.Upsample``; any other value uses a
            learned ``ConvTranspose2d`` (kernel 2, stride 2) that keeps ``in_channels``.
        use_bias: bias flag for conv1, conv2 and the skip conv; conv3 and the
            transposed conv always have a bias.
        bn_type: normalisation kind passed to ``get_bn_layer``.
        act: zero-argument callable returning a new activation module; it is
            called once per activation slot.
    """
    def __init__(self, in_channels, out_channels, upsample=True, upsample_type='bilinear', use_bias=True, bn_type='none', act=nn.ReLU):
        super().__init__()
        self.upsample = upsample
        
        if upsample:
            if upsample_type == 'bilinear':
                self.upsample_layer = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
            else:
                # Learned 2x upsampling; the channel count is unchanged.
                self.upsample_layer = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
        
        # Main path.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=use_bias)
        self.bn1 = get_bn_layer(bn_type, out_channels)
        self.act1 = act()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=use_bias)
        self.bn2 = get_bn_layer(bn_type, out_channels)
        self.act2 = act()
        # NOTE(review): conv3 always has a bias, unlike the other convs (use_bias ignored).
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0)
        
        # Skip path: 1x1 projection so the residual sum has matching channels.
        self.conv_skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=use_bias)
        self.bn_skip = get_bn_layer(bn_type, out_channels)
        
        # Normalisation + activation applied to the residual sum.
        self.bn_final = get_bn_layer(bn_type, out_channels)
        self.act_final = act()

    def forward(self, x):
        """Return act(norm(main(x') + skip(x'))), where x' is x upsampled if enabled."""
        if self.upsample:
            x = self.upsample_layer(x)
            
        skip = self.bn_skip(self.conv_skip(x))
        h = self.act1(self.bn1(self.conv1(x)))
        h = self.act2(self.bn2(self.conv2(h)))
        h = self.conv3(h)
        
        return self.act_final(self.bn_final(h + skip))

# Residual downsampling block used by the Encoder and Discriminator image paths.
class DBlock(nn.Module):
    """Residual block with optional 2x2 average-pool downsampling at the end.

    Main path: conv3x3 -> norm -> act -> conv3x3 -> norm -> act -> conv1x1.
    Skip path: conv1x1 -> norm (projects ``in_channels`` to ``out_channels``).
    Output:    act(norm(main + skip)), followed by ``AvgPool2d(2)`` when
    ``pool`` is True (halving H and W).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        pool: whether to downsample the output with 2x2 average pooling.
        use_bias: bias flag for conv1, conv2 and the skip conv; conv3 always
            has a bias.
        bn_type: normalisation kind passed to ``get_bn_layer``.
        act: zero-argument callable returning a new activation module.
    """
    def __init__(self, in_channels, out_channels, pool=True, use_bias=True, bn_type='none', act=nn.ReLU):
        super().__init__()
        self.pool = pool
        
        # Main path.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=use_bias)
        self.bn1 = get_bn_layer(bn_type, out_channels)
        self.act1 = act()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=use_bias)
        self.bn2 = get_bn_layer(bn_type, out_channels)
        self.act2 = act()
        # NOTE(review): conv3 always has a bias, unlike the other convs (use_bias ignored).
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=1, padding=0)

        # Skip path: 1x1 projection so the residual sum has matching channels.
        self.conv_skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=use_bias)
        self.bn_skip = get_bn_layer(bn_type, out_channels)

        # Normalisation + activation applied to the residual sum.
        self.bn_final = get_bn_layer(bn_type, out_channels)
        self.act_final = act()
        
        if self.pool:
            self.pool_layer = nn.AvgPool2d(2)

    def forward(self, x):
        """Return act(norm(main(x) + skip(x))), average-pooled 2x2 if ``pool``."""
        skip = self.bn_skip(self.conv_skip(x))
        
        h = self.act1(self.bn1(self.conv1(x)))
        h = self.act2(self.bn2(self.conv2(h)))
        h = self.conv3(h)
        
        out = self.act_final(self.bn_final(h + skip))
        if self.pool:
            out = self.pool_layer(out)
        return out

class Generator(nn.Module):
    """Decoder G mapping a latent vector to an image in [-1, 1].

    A dense layer (+ optional norm) produces a 2x2 map with ``32 * channels``
    channels. Nine GBlocks (``g_block1`` .. ``g_block9``) follow, each doubling H
    and W, except ``g_block1`` when ``upsample_first`` is False. A 1x1 conv and
    Tanh then produce a 3-channel image. The output is therefore 512x512 with
    ``upsample_first=False`` and 1024x1024 with ``upsample_first=True``.

    Args:
        latent_size: dimension of the input latent vector.
        channels: width multiplier for the block stack.
        upsample_first: whether ``g_block1`` also upsamples.
        upsample_type: forwarded to every GBlock.
        bn_type: normalisation kind; layers keep their bias only when 'none'.
        act_type: 'lrelu' for LeakyReLU(0.2), anything else for ReLU.
    """
    def __init__(self, latent_size, channels=3, upsample_first=True, upsample_type='bilinear', bn_type='none', act_type='lrelu'):
        super().__init__()

        use_bias = bn_type == 'none'

        def act():
            if act_type == 'lrelu':
                return nn.LeakyReLU(0.2, inplace=True)
            return nn.ReLU(inplace=True)

        dense_features = 2 * 2 * 32 * channels
        self.init_dense = nn.Linear(latent_size, dense_features, bias=use_bias)
        self.init_bn = get_bn_layer(bn_type, num_features=dense_features, dims=1)
        self.init_channels = 32 * channels

        # (input, output) channel multipliers for g_block1 .. g_block9, in order.
        # Blocks are created (and registered) in this order, so parameter names
        # and initialisation order match an explicit attribute-by-attribute build.
        multipliers = [(32, 32), (32, 32), (32, 16), (16, 8), (8, 4), (4, 3), (3, 2), (2, 1), (1, 1)]
        for index, (mult_in, mult_out) in enumerate(multipliers, start=1):
            block = GBlock(
                mult_in * channels,
                mult_out * channels,
                upsample=upsample_first if index == 1 else True,
                upsample_type=upsample_type,
                use_bias=use_bias,
                bn_type=bn_type,
                act=act,
            )
            setattr(self, f'g_block{index}', block)

        self.final_conv = nn.Conv2d(1 * channels, 3, kernel_size=1, padding=0)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map ``(batch, latent_size)`` latents to ``(batch, 3, H, W)`` images in [-1, 1]."""
        h = self.init_bn(self.init_dense(x))
        # (batch, 4 * C) -> (batch, C, 2, 2) for the convolutional stack.
        h = h.view(-1, self.init_channels, 2, 2)
        for index in range(1, 10):
            h = getattr(self, f'g_block{index}')(h)
        return self.tanh(self.final_conv(h))

class Encoder(nn.Module):
    """Encoder E mapping an image to a latent vector.

    Eight DBlocks turn a ``(batch, 3, image_size, image_size)`` image into a
    ``32 * channels``-channel feature map. The first seven blocks halve H and W
    (floor), so the map side is ``image_size // 128`` (4 for 512x512 inputs).
    Two dense layers then project the flattened map to ``latent_size``.

    Args:
        image_size: input height/width in pixels; must be at least 128.
        latent_size: dimension of the output latent vector.
        channels: width multiplier for the DBlock stack.
        bn_type: normalisation kind; layers keep their bias only when 'none'.
        act_type: 'lrelu' for LeakyReLU(0.2), anything else for ReLU.

    Raises:
        ValueError: if ``image_size`` is smaller than 128.
    """

    def __init__(self, image_size, latent_size, channels=3, bn_type='none', act_type='lrelu'):
        super().__init__()
        use_bias = bn_type == 'none'
        act = lambda: nn.LeakyReLU(0.2, inplace=True) if act_type == 'lrelu' else nn.ReLU(inplace=True)

        # Side of the feature map after the seven pooling DBlocks; the flattened
        # size must match final_dense1's input.
        feat_side = int(image_size) // 2 ** 7
        if feat_side < 1:
            raise ValueError(f"image_size must be at least {2 ** 7}, got {image_size}")
        feat_dim = 32 * channels * feat_side * feat_side

        self.blocks = nn.Sequential(
            DBlock(3, 1 * channels, use_bias=use_bias, bn_type=bn_type, act=act),                 # -> H/2
            DBlock(1 * channels, 2 * channels, use_bias=use_bias, bn_type=bn_type, act=act),      # -> H/4
            DBlock(2 * channels, 3 * channels, use_bias=use_bias, bn_type=bn_type, act=act),      # -> H/8
            DBlock(3 * channels, 4 * channels, use_bias=use_bias, bn_type=bn_type, act=act),      # -> H/16
            DBlock(4 * channels, 8 * channels, use_bias=use_bias, bn_type=bn_type, act=act),      # -> H/32
            DBlock(8 * channels, 16 * channels, use_bias=use_bias, bn_type=bn_type, act=act),     # -> H/64
            DBlock(16 * channels, 32 * channels, use_bias=use_bias, bn_type=bn_type, act=act),    # -> H/128
            DBlock(32 * channels, 32 * channels, pool=False, use_bias=use_bias, bn_type=bn_type, act=act)  # -> H/128
        )

        self.final_dense1 = nn.Linear(feat_dim, 32 * channels * 2 * 2, bias=use_bias)
        self.final_bn = get_bn_layer(bn_type, num_features=32 * channels * 2 * 2, dims=1)
        self.final_act = act()
        self.final_dense2 = nn.Linear(32 * channels * 2 * 2, latent_size)

    def forward(self, x):
        """Map ``(batch, 3, image_size, image_size)`` images to ``(batch, latent_size)``."""
        x = self.blocks(x)
        x = torch.flatten(x, start_dim=1)
        x = self.final_act(self.final_bn(self.final_dense1(x)))
        x = self.final_dense2(x)
        return x

class Discriminator(nn.Module):
    """Joint critic D(x, z) over (image, latent) pairs.

    - The latent passes through three 512-unit dense layers.
    - The image passes through eight DBlocks (seven of them halve H and W),
      giving a map with ``image_size // 128`` side.
    - Both are flattened and concatenated, then mapped to a ``32 * channels``
      feature vector.
    - A final linear layer turns the features into one score.

    Args:
        image_size: input height/width in pixels; must be at least 128.
        latent_size: dimension of the latent input.
        channels: width multiplier for the DBlock stack.
        bn_type: normalisation kind; layers keep their bias only when 'none'.
        act_type: 'lrelu' for LeakyReLU(0.2), anything else for ReLU.

    Raises:
        ValueError: if ``image_size`` is smaller than 128.
    """

    def __init__(self, image_size, latent_size, channels=3, bn_type='none', act_type='lrelu'):
        super().__init__()
        use_bias = bn_type == 'none'
        act = lambda: nn.LeakyReLU(0.2, inplace=True) if act_type == 'lrelu' else nn.ReLU(inplace=True)

        # Side of the image feature map after the seven pooling DBlocks
        # (4 for 512x512 inputs); its flattened size feeds common_path_feature.
        feat_side = int(image_size) // 2 ** 7
        if feat_side < 1:
            raise ValueError(f"image_size must be at least {2 ** 7}, got {image_size}")
        image_feat_dim = 32 * channels * feat_side * feat_side

        self.latent_path = nn.Sequential(
            nn.Linear(latent_size, 512, bias=use_bias),
            get_bn_layer(bn_type, 512, dims=1),
            act(),
            nn.Linear(512, 512, bias=use_bias),
            get_bn_layer(bn_type, 512, dims=1),
            act(),
            nn.Linear(512, 512, bias=use_bias),
            get_bn_layer(bn_type, 512, dims=1),
            act()
        )
        self.image_path = nn.Sequential(
            DBlock(3, 1 * channels, use_bias=use_bias, bn_type=bn_type, act=act),                 # -> H/2
            DBlock(1 * channels, 2 * channels, use_bias=use_bias, bn_type=bn_type, act=act),      # -> H/4
            DBlock(2 * channels, 3 * channels, use_bias=use_bias, bn_type=bn_type, act=act),      # -> H/8
            DBlock(3 * channels, 4 * channels, use_bias=use_bias, bn_type=bn_type, act=act),      # -> H/16
            DBlock(4 * channels, 8 * channels, use_bias=use_bias, bn_type=bn_type, act=act),      # -> H/32
            DBlock(8 * channels, 16 * channels, use_bias=use_bias, bn_type=bn_type, act=act),     # -> H/64
            DBlock(16 * channels, 32 * channels, use_bias=use_bias, bn_type=bn_type, act=act),    # -> H/128
            DBlock(32 * channels, 32 * channels, pool=False, use_bias=use_bias, bn_type=bn_type, act=act)  # -> H/128
        )

        # Common path, split so the penultimate features can be returned on their own.
        self.common_path_feature = nn.Sequential(
            nn.Linear(image_feat_dim + 512, 32 * channels, bias=use_bias),
            get_bn_layer(bn_type, 32 * channels, dims=1),
            act()
        )
        self.common_path_classifier = nn.Linear(32 * channels, 1)

    def forward(self, image, latent, return_features=False):
        """Score (image, latent) pairs.

        Returns:
            ``(batch, 1)`` scores, or the ``(batch, 32 * channels)`` features
            feeding the classifier when ``return_features`` is True.
        """
        l = self.latent_path(latent)
        x = self.image_path(image)
        x = torch.flatten(x, start_dim=1)

        combined = torch.cat([x, l], dim=1)
        features = self.common_path_feature(combined)

        if return_features:
            return features
        return self.common_path_classifier(features)

# Train.py

In [None]:
import os

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.utils as vutils
from sklearn import metrics
from tqdm import tqdm

from model import Generator, Encoder, Discriminator
from dataset import get_dataloaders
# losses defines `l1` (there is no `compute_l1`); train() calls it as `l1`.
from losses import gradient_penalty, l1

def get_anomaly_score(generator, encoder, discriminator, images, lambda_):
    """Per-image anomaly score mixing pixel and discriminator-feature errors.

    The three networks are switched to eval mode and left there. Each image x
    is reconstructed as G(E(x)) and scored as::

        (1 - lambda_) * mean|x - G(E(x))|
            + lambda_ * mean|f(x, E(x)) - f(G(E(x)), E(x))|

    where ``f`` is ``discriminator(..., return_features=True)``.

    Args:
        generator, encoder, discriminator: the trained networks.
        images: batch of images with dims (batch, C, H, W).
        lambda_: weight of the feature term, in [0, 1].

    Returns:
        1-D tensor with one score per image (``evaluate`` treats higher
        scores as more anomalous).
    """
    for net in (generator, encoder, discriminator):
        net.eval()

    with torch.no_grad():
        z = encoder(images)
        x_rec = generator(z)

        # Pixel reconstruction error, L_R.
        pixel_err = (images - x_rec).abs().mean(dim=[1, 2, 3])

        # Discriminator feature-matching error, L_fD.
        f_real = discriminator(images, z, return_features=True)
        f_rec = discriminator(x_rec, z, return_features=True)
        feature_err = (f_real - f_rec).abs().mean(dim=1)

        return (1 - lambda_) * pixel_err + lambda_ * feature_err

def save_reconstruction_grid(real_images, netE, netG, epoch, device, save_path="reconstructions"):
    """Save up to 8 real images next to their G(E(x)) reconstructions.

    The figure is written to ``{save_path}/epoch_{epoch}.png``; ``save_path`` is
    created when it does not exist. The grid has 8 images per row: with 8 or
    more inputs the first row holds real images and the second their
    reconstructions. ``netE`` and ``netG`` are put in eval mode and left that
    way. ``device`` is accepted but not used.
    """
    import os
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    netE.eval()
    netG.eval()
    with torch.no_grad():
        originals = real_images[:8]
        reconstructions = netG(netE(originals))

        stacked = torch.cat([originals, reconstructions], dim=0)
        # Rescale from the Tanh range [-1, 1] to [0, 1] for display.
        grid = vutils.make_grid(stacked, nrow=8, normalize=True, value_range=(-1, 1))

        fig, ax = plt.subplots(figsize=(12, 4))
        ax.imshow(grid.permute(1, 2, 0).cpu().numpy())
        ax.set_title(f"Epoch {epoch} Reconstructions (Top: Real, Bottom: G(E(x)))")
        ax.axis('off')
        fig.savefig(f"{save_path}/epoch_{epoch}.png")
        plt.close(fig)

def plot_loss_curves(history, save_path="plots"):
    """Plot the logged D and G/E losses and save them as ``{save_path}/loss_curve.png``.

    Args:
        history: mapping with sequences under ``'d_loss'`` and ``'ge_loss'``;
            any other keys are ignored.
        save_path: output directory, created when it does not exist.
    """
    import os
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    fig, ax = plt.subplots(figsize=(10, 5))
    for key, label in (('d_loss', 'D Loss'), ('ge_loss', 'GE Loss')):
        ax.plot(history[key], label=label)
    ax.set_xlabel('Iterations (x100)')
    ax.set_ylabel('Loss')
    ax.set_title('CBiGAN Training Losses')
    ax.legend()
    fig.savefig(f"{save_path}/loss_curve.png")
    plt.close(fig)



def evaluate(netG, netE, netD, test_loader, device, lambda_, classes):
    """Score the test set and report the ROC AUC of benign-vs-anomaly detection.

    Images of the ``'benign'`` class are labelled 0 and every other class 1
    (anomaly). Scores come from ``get_anomaly_score``, which leaves the
    networks in eval mode.

    Args:
        netG, netE, netD: generator, encoder and discriminator.
        test_loader: iterable of ``(images, labels)`` batches.
        device: device the images are moved to.
        lambda_: feature-term weight forwarded to ``get_anomaly_score``.
        classes: index -> class-name list of the test set.

    Returns:
        The AUC as a float, or ``float('nan')`` when the loader is empty or
        contains only one of the two labels (AUC is undefined then).

    Raises:
        ValueError: if ``'benign'`` is not in ``classes``.
    """
    benign_idx = classes.index('benign')
    score_batches = []
    label_batches = []

    for images, labels in test_loader:
        scores = get_anomaly_score(netG, netE, netD, images.to(device), lambda_)
        score_batches.append(scores.cpu().numpy())
        # Labels: 0 for benign, 1 for any other class (anomaly).
        label_batches.append((labels != benign_idx).int().numpy())

    if not score_batches:
        print(">> Eval AUC: undefined (empty test set)")
        return float('nan')

    all_scores = np.concatenate(score_batches)
    all_labels = np.concatenate(label_batches)
    if np.unique(all_labels).size < 2:
        print(">> Eval AUC: undefined (test set contains a single class)")
        return float('nan')

    fpr, tpr, _ = metrics.roc_curve(all_labels, all_scores)
    auc = float(metrics.auc(fpr, tpr))
    print(f">> Eval AUC: {auc:.4f}")
    return auc



def _args_to_config(args):
    """Return the public, non-callable attributes of ``args`` as a plain dict.

    ``dir`` is used instead of ``vars`` so class-level attributes (as defined
    on ``main``'s ``Args`` class) are included.
    """
    config = {}
    for name in dir(args):
        if name.startswith('_'):
            continue
        value = getattr(args, name)
        if not callable(value):
            config[name] = value
    return config


def _discriminator_step(netG, netE, netD, optD, images, z, gp_weight, device):
    """Run one WGAN-GP critic update and return the total critic loss tensor.

    The "real" pair is (x, E(x)) and the "fake" pair is (G(z), z). G and E are
    evaluated without gradients, so only the critic's parameters change.
    """
    with torch.no_grad():
        real_latent = netE(images)
        fake_images = netG(z)

    optD.zero_grad()
    real_score = netD(images, real_latent)
    fake_score = netD(fake_images, z)
    # The critic minimises D(fake) - D(real).
    d_loss = (fake_score - real_score).mean()
    # Interpolate (x, E(x)) with (G(z), z): images with images, latents with latents.
    gp_loss = gradient_penalty(netD, images, fake_images, real_latent, z, device)
    d_total_loss = d_loss + gp_weight * gp_loss
    d_total_loss.backward()
    optD.step()
    return d_total_loss


def _generator_encoder_step(netG, netE, netD, optGE, images, z, alpha):
    """Run one joint G/E update and return the total G/E loss tensor.

    loss = (1 - alpha) * L_E,G + alpha * L_C, where
      L_E,G = mean(D(x, E(x)) - D(G(z), z))               (adversarial)
      L_C   = mean|x - G(E(x))|_1 + mean|z - E(G(z))|_1   (consistency)
    """
    optGE.zero_grad()
    real_latent = netE(images)   # E(x)
    fake_images = netG(z)        # G(z)

    generator_encoder_loss = (netD(images, real_latent) - netD(fake_images, z)).mean()  # L_E,G

    images_reconstruction_loss = l1(images, netG(real_latent)).mean()  # L_R : x vs G(E(x))
    latent_reconstruction_loss = l1(z, netE(fake_images)).mean()       # L_R': z vs E(G(z))
    consistency_loss = images_reconstruction_loss + latent_reconstruction_loss  # L_C

    ge_total_loss = (1 - alpha) * generator_encoder_loss + alpha * consistency_loss  # L*_E,G
    ge_total_loss.backward()
    # The critic's parameters also receive gradients here; optD.zero_grad()
    # clears them at the start of the next critic step.
    optGE.step()
    return ge_total_loss


def train(args):
    """Train CBiGAN (G, E, D) on benign images and evaluate AUC after every epoch.

    D is updated on every iteration. G and E are updated jointly when
    ``iteration % args.d_iter == 0``. After each epoch the function writes:
    - a reconstruction grid;
    - the loss plot;
    - a checkpoint, ``final_model.pth``;
    - the training log, ``final_training_log.csv``.

    ``args`` must provide: image_size, batch_size, latent_size, channels, lr,
    ge_beta1, ge_beta2, d_beta1, d_beta2, gp_weight, alpha, d_iter, epochs and
    lambda_.

    Returns:
        History dict: ``'d_loss'``/``'ge_loss'`` (sampled every 100 iterations
        of each epoch) and ``'auc'`` (one entry per epoch).

    Raises:
        ValueError: if ``args.image_size`` is not 512. With
            ``upsample_first=False`` the Generator always emits 512x512 images.
    """
    import pandas as pd

    if args.image_size != 512:
        raise ValueError(f"Generator outputs 512x512 images; got image_size={args.image_size}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader, classes = get_dataloaders(args.image_size, args.batch_size)

    netG = Generator(args.latent_size, args.channels, upsample_first=False, bn_type='batch').to(device)
    netE = Encoder(args.image_size, args.latent_size, args.channels, bn_type='instance').to(device)
    netD = Discriminator(args.image_size, args.latent_size, args.channels, bn_type='layer').to(device)

    optGE = optim.Adam(list(netG.parameters()) + list(netE.parameters()),
                       lr=args.lr, betas=(float(args.ge_beta1), float(args.ge_beta2)))
    optD = optim.Adam(netD.parameters(), lr=args.lr, betas=(args.d_beta1, args.d_beta2))

    history = {'d_loss': [], 'ge_loss': [], 'auc': []}
    config = _args_to_config(args)
    d_value = ge_value = float('nan')

    for epoch in range(args.epochs):
        netG.train()
        netE.train()
        netD.train()
        images = None
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}")

        for i, (images, _) in enumerate(pbar):
            images = images.to(device)
            # The Generator's BatchNorm1d cannot normalise a single sample in train mode.
            if images.size(0) < 2:
                continue
            z = torch.randn(images.size(0), args.latent_size, device=device)

            d_value = _discriminator_step(netG, netE, netD, optD, images, z,
                                          args.gp_weight, device).item()
            if i % args.d_iter == 0:
                ge_value = _generator_encoder_step(netG, netE, netD, optGE, images, z,
                                                   args.alpha).item()

            if i % 100 == 0:
                history['d_loss'].append(d_value)
                history['ge_loss'].append(ge_value)

            pbar.set_postfix({"D": f"{d_value:.3f}", "GE": f"{ge_value:.3f}"})

        if images is not None:
            save_reconstruction_grid(images, netE, netG, epoch, device)
        plot_loss_curves(history)

        auc = evaluate(netG, netE, netD, test_loader, device, args.lambda_, classes)
        history['auc'].append(auc)
        if auc is not None:  # guard in case evaluate() returns nothing
            print(f"Epoch {epoch} | AUC: {auc:.4f}")

        # Checkpoint and log every epoch so an interrupted run keeps its progress.
        torch.save({
            'generator': netG.state_dict(),
            'encoder': netE.state_dict(),
            'discriminator': netD.state_dict(),
            'config': config,
        }, 'final_model.pth')
        # The history lists have different lengths; Series pads the shorter ones with NaN.
        pd.DataFrame({key: pd.Series(values, dtype=float) for key, values in history.items()}) \
            .to_csv('final_training_log.csv', index=False)

    print("Training completed!")
    return history

            
def main():
    """Train CBiGAN with the default hyper-parameters defined below."""
    class Args:
        """Hyper-parameters read by ``train`` (as class attributes)."""
        image_size = 512
        batch_size = 32
        latent_size = 128
        channels = 3
        lr = 1e-4
        # NOTE(review): ge_beta2=0.1 is unusually low (D uses 0.9) — confirm intended.
        ge_beta1, ge_beta2 = 0.0, 0.1
        d_beta1, d_beta2 = 0.0, 0.9
        gp_weight = 10
        alpha = 1e-5
        d_iter = 1 # Update GE every d_iter steps
        epochs = 50
        lambda_ = 0.1 # Weight for feature distance in scoring

    # Must be outside the class body: `Args` is only bound once its body finishes.
    train(Args())
    
