In [2]:
import math
import os
import random
import time
import warnings

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


In [3]:

# DDOSDataset: RGB image / segmentation-mask pairs for per-pixel classification.
class DDOSDataset(Dataset):
    """Image/segmentation-mask pairs read from the DDOS folder layout.

    ``__init__`` walks::

        <root_dir>/<split>/<env>/<flight>/image/*
        <root_dir>/<split>/<env>/<flight>/segmentation/*

    where ``env`` is ``'neighbourhood'`` or ``'park'`` and ``flight`` is a
    purely numeric folder name (visited in numeric order). Images and masks are
    paired by position after sorting each folder listing, so both folders must
    share the same file names / sort order.

    Args:
        root_dir: Dataset root directory (the default is a hard-coded local path).
        split: Sub-folder of ``root_dir`` to load, e.g. ``'train'``.
        transform: Optional callable applied to BOTH the float32 image (values
            in [0, 1]) and the uint8 mask before resizing; its results are passed
            to ``cv2.resize``, so it must return numpy arrays.
        max_images: If more pairs than this are found, a random subset of this
            size is kept.
        seed: Optional seed for that subset. ``None`` (default) draws from the
            global ``random`` module, so a prior ``random.seed(...)`` still applies.

    ``__getitem__`` returns ``(image, mask)``: a float32 tensor of shape
    (3, 128, 128) and an int64 tensor of class indices (0-9) resized to 128x128.
    """

    # Raw mask gray level -> class index; any unlisted value maps to DEFAULT_CLASS.
    CLASS_MAPPING = {255: 0, 240: 1, 225: 2, 210: 3, 195: 4, 180: 5, 165: 6, 150: 7, 140: 8, 0: 9}
    DEFAULT_CLASS = 9
    # 256-entry table equivalent to CLASS_MAPPING.get(v, DEFAULT_CLASS) for every uint8 v;
    # indexing with it maps a whole mask in one vectorized operation.
    _CLASS_LUT = np.full(256, DEFAULT_CLASS, dtype=np.int64)
    _CLASS_LUT[list(CLASS_MAPPING.keys())] = list(CLASS_MAPPING.values())

    def __init__(self, root_dir=r'\DS2\DDOS\data', split='train', transform=None, max_images=10000, seed=None):
        self.root_dir = os.path.join(root_dir, split)
        self.transform = transform
        self.max_images = max_images
        self.image_paths = []
        self.mask_paths = []

        for env in ('neighbourhood', 'park'):
            env_dir = os.path.join(self.root_dir, env)
            if not os.path.isdir(env_dir):
                continue
            flight_folders = sorted((f for f in os.listdir(env_dir) if f.isdigit()), key=int)
            for flight in flight_folders:
                flight_dir = os.path.join(env_dir, flight)
                image_dir = os.path.join(flight_dir, 'image')
                seg_dir = os.path.join(flight_dir, 'segmentation')
                if not (os.path.isdir(image_dir) and os.path.isdir(seg_dir)):
                    continue
                images = sorted(os.listdir(image_dir))
                masks = sorted(os.listdir(seg_dir))
                if len(images) != len(masks):
                    # zip() below silently drops the surplus files and the
                    # remaining pairs may be misaligned - make that visible.
                    warnings.warn(
                        f"{flight_dir}: {len(images)} images vs {len(masks)} masks; "
                        "pairing by sorted order may be misaligned",
                        stacklevel=2,
                    )
                for img, mask in zip(images, masks):
                    self.image_paths.append(os.path.join(image_dir, img))
                    self.mask_paths.append(os.path.join(seg_dir, mask))

        if len(self.image_paths) > self.max_images:
            # Same index list for both path lists keeps image/mask pairs together.
            rng = random if seed is None else random.Random(seed)
            indices = rng.sample(range(len(self.image_paths)), self.max_images)
            self.image_paths = [self.image_paths[i] for i in indices]
            self.mask_paths = [self.mask_paths[i] for i in indices]

    def __len__(self):
        return len(self.image_paths)

    @classmethod
    def _map_mask_to_classes(cls, mask):
        """Map raw mask gray levels to class indices (int64), unknown values -> DEFAULT_CLASS."""
        mask = np.asarray(mask)
        if mask.dtype == np.uint8:
            return cls._CLASS_LUT[mask]
        # Non-uint8 masks (e.g. from a custom transform): element-wise dict lookup.
        return np.vectorize(lambda v: cls.CLASS_MAPPING.get(v, cls.DEFAULT_CLASS))(mask)

    def __getitem__(self, idx):
        # Context managers close the underlying file handles once pixels are read.
        with Image.open(self.image_paths[idx]) as image_file:
            img = np.array(image_file.convert('RGB'), dtype=np.float32) / 255.0
        with Image.open(self.mask_paths[idx]) as mask_file:
            mask = np.array(mask_file, dtype=np.uint8)

        if self.transform:
            img = self.transform(img)
            mask = self.transform(mask)

        img = cv2.resize(img, (128, 128))
        # Nearest-neighbour so no new (invalid) gray levels are interpolated into the mask.
        mask = cv2.resize(mask, (128, 128), interpolation=cv2.INTER_NEAREST)

        mask = self._map_mask_to_classes(mask)

        # HWC -> CHW as expected by nn.Conv2d.
        img = torch.tensor(img.transpose(2, 0, 1), dtype=torch.float32)
        mask = torch.tensor(mask, dtype=torch.long)
        return img, mask


In [4]:

# LiquidNeuralNetwork: convolutional leaky ("liquid") recurrent cell with a closed-form single-step update.
class LiquidNeuralNetwork(nn.Module):
    """Convolutional liquid-style cell producing per-pixel class logits.

    One call to ``forward`` performs a single update step::

        u      = tanh(conv3x3_in(x))
        d      = exp(-dt / tau)
        h_new  = d * h + (1 - d) * tanh(conv3x3_liquid(u + h))
        logits = conv1x1_out(h_new)

    All convolutions preserve spatial size (3x3 with padding 1, or 1x1).

    Args:
        input_channels: Channels of the input tensor.
        hidden_size: Channels of the hidden state.
        num_classes: Channels of the output logits.
        tau: Time constant of the leaky update; must be > 0.
        dt: Integration step used in the decay factor.
    """

    def __init__(self, input_channels=3, hidden_size=256, num_classes=10, tau=1.0, dt=1.0):
        super().__init__()
        if tau <= 0:
            raise ValueError(f"tau must be positive, got {tau}")
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.tau = tau
        self.dt = dt
        self.input_layer = nn.Conv2d(input_channels, hidden_size, kernel_size=3, padding=1)
        self.liquid_layer = nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1)
        self.output_layer = nn.Conv2d(hidden_size, num_classes, kernel_size=1)

    def forward(self, x, hidden=None):
        """Run one update step.

        Args:
            x: Input tensor indexed as (batch, channels, height, width).
            hidden: Previous hidden state of shape (batch, hidden_size, height, width),
                or ``None`` to start from zeros.

        Returns:
            ``(logits, hidden)``: logits of shape (batch, num_classes, height, width)
            and the updated hidden state.
        """
        if hidden is None:
            # Allocate directly with x's device and dtype (no CPU alloc + copy).
            hidden = x.new_zeros(x.size(0), self.hidden_size, x.size(2), x.size(3))

        input_state = torch.tanh(self.input_layer(x))

        # dt and tau are Python floats: torch.exp() only accepts tensors and would
        # raise TypeError, so compute the scalar decay with math.exp.
        decay = math.exp(-self.dt / self.tau)
        drive = self.liquid_layer(input_state + hidden)
        hidden = decay * hidden + (1.0 - decay) * torch.tanh(drive)

        output = self.output_layer(hidden)
        return output, hidden


In [None]:
 
# Training: seed RNGs, build data loaders and model, train with Adam + per-pixel cross-entropy.
SEED = 42
# Seed every RNG in play (dataset subsampling uses `random`, shuffling/init use torch).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Pinned host memory only helps host->GPU copies; on CPU it just emits a warning.
pin_memory = device.type == 'cuda'

dataset = DDOSDataset(split='train', max_images=500)
if len(dataset) == 0:
    raise RuntimeError(f"No training pairs found under {dataset.root_dir!r}")
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0, pin_memory=pin_memory)

val_dataset = DDOSDataset(split='validation')
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=0, pin_memory=pin_memory)

model = LiquidNeuralNetwork(input_channels=3, hidden_size=256, num_classes=10).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    samples_seen = 0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)

    for images, masks in progress_bar:
        images, masks = images.to(device), masks.to(device)
        # hidden=None -> every batch starts from a zero hidden state.
        outputs, hidden = model(images)
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Weight by batch size so a smaller final batch doesn't skew the epoch mean.
        running_loss += batch_loss * images.size(0)
        samples_seen += images.size(0)
        progress_bar.set_postfix({'loss': f"{batch_loss:.4f}"})

    avg_loss = running_loss / samples_seen
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")


In [None]:

# Persist the trained weights (state_dict only); re-create the model class before loading them back.
model_path = 'liquid_nn_cfc_model.pth'
trained_weights = model.state_dict()
torch.save(trained_weights, model_path)
print(f"Model saved to {model_path}")


In [None]:

# Validation: time inference on the first validation batch and show RGB / ground truth / prediction.

# One display color per class index 0..9 (cv2.imshow treats channels as BGR).
CLASS_COLORS = [
    [0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0],
    [255, 0, 255], [0, 255, 255], [128, 128, 128], [255, 128, 0], [128, 0, 128]
]


def colorize_mask(mask, num_classes=10):
    """Turn a 2-D class-index mask into an (H, W, 3) uint8 color image.

    Pixels whose class is outside ``range(num_classes)`` stay black.
    """
    colored_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    for i in range(num_classes):
        colored_mask[mask == i] = CLASS_COLORS[i]
    return colored_mask


def _sync_if_cuda(dev):
    """Wait for queued CUDA work so wall-clock timings are meaningful; no-op on CPU."""
    if dev.type == 'cuda':
        torch.cuda.synchronize(dev)


display_size = (512, 512)

model.eval()
start_time = time.perf_counter()
elapsed = None
with torch.no_grad():
    for images, masks in val_loader:
        images, masks = images.to(device), masks.to(device)
        # Unconditional torch.cuda.synchronize() raises on the CPU fallback device.
        _sync_if_cuda(device)
        t0 = time.perf_counter()
        outputs, _ = model(images)
        _sync_if_cuda(device)
        preds = torch.argmax(outputs, dim=1)
        print(f"Batch inference time: {time.perf_counter() - t0:.4f} seconds")

        img = np.ascontiguousarray(images[0].cpu().numpy().transpose(1, 2, 0))
        # cv2.resize does not accept int64 arrays (torch.long / argmax output);
        # class ids 0..9 fit losslessly in uint8.
        mask = masks[0].cpu().numpy().astype(np.uint8)
        pred = preds[0].cpu().numpy().astype(np.uint8)

        img_resized = cv2.resize(img, display_size, interpolation=cv2.INTER_LINEAR)
        mask_resized = cv2.resize(mask, display_size, interpolation=cv2.INTER_NEAREST)
        pred_resized = cv2.resize(pred, display_size, interpolation=cv2.INTER_NEAREST)

        # Dataset images are RGB; cv2.imshow expects BGR.
        img_resized = cv2.cvtColor(img_resized, cv2.COLOR_RGB2BGR)

        mask_colored = colorize_mask(mask_resized)
        pred_colored = colorize_mask(pred_resized)

        cv2.namedWindow('RGB', cv2.WINDOW_NORMAL)
        cv2.namedWindow('Ground Truth', cv2.WINDOW_NORMAL)
        cv2.namedWindow('Prediction', cv2.WINDOW_NORMAL)

        cv2.imshow('RGB', img_resized)
        cv2.imshow('Ground Truth', mask_colored)
        cv2.imshow('Prediction', pred_colored)

        # Stop the clock before blocking on a key press so the reported time excludes it.
        elapsed = time.perf_counter() - start_time
        cv2.waitKey(0)
        # Only the first batch is visualised.
        break

if elapsed is None:
    elapsed = time.perf_counter() - start_time
print(f"Total validation time: {elapsed:.4f} seconds")
cv2.destroyAllWindows()