# ----------------------
# Imports
# ----------------------

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import timm  # For MobileNetV3/EfficientNet-Lite
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

# ----------------------
# Dataset Loader
# ----------------------

In [None]:
class DroneBatchDataset(Dataset):
    """Sliding-window dataset over recorded episode files.

    Every file in ``npz_folder`` must expose the keys ``depths``, ``actions``
    and ``victim_dirs``. Each sample is a window of ``sequence_length``
    consecutive entries taken at the same offsets from all three arrays.

    Args:
        npz_folder: Directory containing the episode files.
        sequence_length: Number of consecutive steps per sample (>= 1).
        transform: Optional callable applied to the float depth tensor of a
            window before it is returned.

    Raises:
        ValueError: If ``sequence_length`` is smaller than 1.
    """

    _KEYS = ('depths', 'actions', 'victim_dirs')

    def __init__(self, npz_folder, sequence_length=5, transform=None):
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        # Keyed access needs ``.npz`` archives; the ``.npy`` suffix is still
        # accepted so anything the old filter matched keeps being picked up.
        # Sorting makes the sample order independent of os.listdir ordering.
        self.files = sorted(
            os.path.join(npz_folder, f)
            for f in os.listdir(npz_folder)
            if f.endswith(('.npz', '.npy'))
        )
        self.sequence_length = sequence_length
        self.transform = transform
        # One (file, start_index) entry per full window available in a file.
        self.samples = []
        for file in self.files:
            # Only the depth array is needed to count windows.
            num_steps = len(self._load_arrays(file, keys=('depths',))['depths'])
            self.samples.extend(
                (file, start) for start in range(num_steps - sequence_length + 1)
            )
        # Filled lazily by __getitem__: file path -> dict of arrays.
        self.data_cache = {}

    @staticmethod
    def _load_arrays(path, keys=_KEYS):
        """Read the requested arrays from ``path`` and release the file handle."""
        data = np.load(path)
        try:
            return {key: data[key] for key in keys}
        finally:
            # NpzFile keeps the archive open until closed; plain arrays have
            # no ``close`` attribute.
            close = getattr(data, 'close', None)
            if close is not None:
                close()

    def __len__(self):
        """Return the total number of windows across all files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(depths, victim_dirs, actions)`` tensors for window ``idx``.

        ``depths`` gets a new axis inserted at position 1 and is converted to
        float (presumably a channel axis on (T, H, W) frames; confirm against
        the recorded data). ``victim_dirs`` is float, ``actions`` is long.
        """
        file, start_idx = self.samples[idx]
        if file not in self.data_cache:
            self.data_cache[file] = self._load_arrays(file)
        d = self.data_cache[file]
        window = slice(start_idx, start_idx + self.sequence_length)

        depths_seq = torch.from_numpy(np.expand_dims(d['depths'][window], 1)).float()
        if self.transform is not None:
            # The transform receives a tensor; as_tensor also tolerates
            # transforms that hand back a NumPy array.
            depths_seq = torch.as_tensor(self.transform(depths_seq)).float()
        return (
            depths_seq,
            torch.from_numpy(d['victim_dirs'][window]).float(),
            torch.from_numpy(d['actions'][window]).long()
        )

# ----------------------
# Model Definition
# ----------------------

In [None]:
class ConvLSTMCell(nn.Module):
    """Single convolutional LSTM step.

    One convolution over the channel-wise concatenation of the input and the
    previous hidden state produces all four gate pre-activations at once.

    Args:
        input_dim: Number of channels of the input tensor.
        hidden_dim: Number of channels of the hidden and cell states.
        kernel_size: Integer kernel size; padding is ``kernel_size // 2``.
        bias: Whether the convolution uses a bias term.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.conv = nn.Conv2d(
            input_dim + hidden_dim,
            4 * hidden_dim,
            kernel_size,
            padding=kernel_size // 2,
            bias=bias,
        )

    def forward(self, x, h, c):
        """Advance the state by one step and return ``(h_next, c_next)``."""
        stacked = torch.cat([x, h], dim=1)
        # Channel order of the conv output: input, forget, output, candidate.
        in_gate, forget_gate, out_gate, candidate = self.conv(stacked).split(
            self.hidden_dim, dim=1
        )
        c_next = torch.sigmoid(forget_gate) * c + torch.sigmoid(in_gate) * torch.tanh(candidate)
        h_next = torch.sigmoid(out_gate) * torch.tanh(c_next)
        return h_next, c_next

class SimpleConvLSTM(nn.Module):
    """Unrolls one ``ConvLSTMCell`` over the time axis of a sequence.

    Args:
        input_dim: Number of channels of each input frame.
        hidden_dim: Number of channels of the hidden and cell states.
        kernel_size: Kernel size forwarded to the cell.
        bias: Whether the cell's convolution uses a bias term.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size=3, bias=True):
        super().__init__()
        self.cell = ConvLSTMCell(input_dim, hidden_dim, kernel_size, bias)

    def forward(self, x):
        """Run the cell over ``x`` of shape (batch, seq_len, C, H, W).

        Returns:
            The hidden state after every step, stacked along dim 1, i.e.
            shape (batch, seq_len, hidden_dim, H, W).
        """
        batch, seq_len, _, H, W = x.size()
        # new_zeros inherits dtype and device from x, so half/double inputs
        # do not collide with a float32 initial state.
        h = x.new_zeros(batch, self.cell.hidden_dim, H, W)
        c = torch.zeros_like(h)
        outputs = []
        for t in range(seq_len):
            h, c = self.cell(x[:, t], h, c)
            outputs.append(h)
        return torch.stack(outputs, dim=1)

class DroneActionNet(nn.Module):
    """Action classifier over a depth-frame sequence plus a direction vector.

    Each frame goes through a timm feature backbone, the per-frame feature
    maps are fed to a ConvLSTM, and the pooled final hidden state is
    concatenated with an embedding of the last ``victim_dirs`` step before
    the classification head.

    Args:
        num_actions: Number of output classes.
        victim_dir_dim: Size of the last dimension of ``victim_dirs``.
        backbone: timm model name. If the name is not registered but
            ``<name>_100`` is, that variant is used instead.
        convlstm_hidden: Number of ConvLSTM hidden channels.
    """

    def __init__(self, num_actions=9, victim_dir_dim=4, backbone='mobilenetv3_small', convlstm_hidden=32):
        super().__init__()
        # timm registers MobileNetV3 variants with a width suffix (e.g.
        # 'mobilenetv3_small_100'); resolve the bare family name to it
        # instead of failing with "unknown model".
        if not timm.is_model(backbone) and timm.is_model(f'{backbone}_100'):
            backbone = f'{backbone}_100'
        self.backbone = timm.create_model(backbone, pretrained=True, features_only=True, in_chans=1)
        backbone_out_ch = self.backbone.feature_info[-1]['num_chs']
        self.convlstm = SimpleConvLSTM(backbone_out_ch, convlstm_hidden)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.victim_fc = nn.Linear(victim_dir_dim, 16)
        self.fc = nn.Sequential(
            nn.Linear(convlstm_hidden + 16, 64),
            nn.ReLU(),
            nn.Linear(64, num_actions)
        )

    def forward(self, x, victim_dirs):
        """Return logits of shape (batch, num_actions).

        Args:
            x: Frames of shape (batch, seq, 1, H, W).
            victim_dirs: Tensor indexed as (batch, seq, victim_dir_dim); only
                the last step is used.
        """
        batch, seq, _, H, W = x.size()
        # Fold time into the batch axis for the 2D backbone. reshape (unlike
        # view) also accepts non-contiguous inputs.
        x = x.reshape(batch * seq, 1, H, W)
        feats = self.backbone(x)[-1]  # last feature stage of the backbone
        _, C, h, w = feats.size()
        feats = feats.reshape(batch, seq, C, h, w)
        convlstm_out = self.convlstm(feats)
        # Only the hidden state of the last step feeds the classifier.
        pooled = self.pool(convlstm_out[:, -1]).reshape(batch, -1)
        victim_emb = self.victim_fc(victim_dirs[:, -1])
        out = torch.cat([pooled, victim_emb], dim=1)
        logits = self.fc(out)
        return logits

# ----------------------
# Training and Evaluation
# ----------------------

In [None]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Each batch is ``(x, victim_dirs, y)``; the target is the label of the
    last step, ``y[:, -1]``.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    for frames, dirs, labels in loader:
        frames = frames.to(device)
        dirs = dirs.to(device)
        targets = labels[:, -1].to(device)
        batch_size = frames.size(0)

        optimizer.zero_grad()
        logits = model(frames, dirs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        # The batch loss is weighted by batch size so that a smaller final
        # batch does not skew the epoch average.
        loss_sum += loss.item() * batch_size
        n_correct += (logits.argmax(1) == targets).sum().item()
        n_seen += batch_size
    return loss_sum / n_seen, n_correct / n_seen

def evaluate(model, loader, device):
    """Score ``model`` on ``loader`` without gradient tracking.

    Predictions are compared against the label of the last step of each
    sequence, ``y[:, -1]``.

    Returns:
        ``(accuracy, macro_f1, confusion_matrix)`` computed with scikit-learn.
    """
    model.eval()
    predictions = []
    targets = []
    with torch.no_grad():
        for frames, dirs, labels in loader:
            frames = frames.to(device)
            dirs = dirs.to(device)
            last_labels = labels[:, -1].to(device)
            batch_pred = model(frames, dirs).argmax(1)
            predictions.extend(batch_pred.cpu().numpy())
            targets.extend(last_labels.cpu().numpy())
    return (
        accuracy_score(targets, predictions),
        f1_score(targets, predictions, average='macro'),
        confusion_matrix(targets, predictions),
    )

def plot_confusion_matrix(cm, class_names):
    """Show ``cm`` as an annotated heatmap labelled with ``class_names``.

    Rows are labelled 'True' and columns 'Predicted'.
    """
    plt.figure(figsize=(8, 6))
    heatmap_options = {
        'annot': True,
        'fmt': 'd',
        'cmap': 'Blues',
        'xticklabels': class_names,
        'yticklabels': class_names,
    }
    sns.heatmap(cm, **heatmap_options)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()

# ----------------------
# Quantization
# ----------------------

In [None]:
def quantize_model(model):
    """Return a dynamically quantized (int8 ``nn.Linear``) version of ``model``.

    Note that ``model`` itself is switched to eval mode and moved to the CPU
    before quantization.
    """
    # Both calls modify the module in place and return it.
    cpu_model = model.eval().cpu()
    return torch.quantization.quantize_dynamic(
        cpu_model,
        {nn.Linear},
        dtype=torch.qint8,
    )

In [None]:
# User: Set these paths and hyperparameters
if __name__ == "__main__":
    train_dir = '/path/to/train'
    val_dir = '/path/to/val'
    sequence_length = 5
    batch_size = 8
    num_epochs = 10
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Pinned host memory only helps host-to-GPU copies.
    pin_memory = device.type == 'cuda'

    train_dataset = DroneBatchDataset(train_dir, sequence_length=sequence_length)
    val_dataset = DroneBatchDataset(val_dir, sequence_length=sequence_length)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory)

    model = DroneActionNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_acc, val_f1, val_cm = evaluate(model, val_loader, device)
        print(f"Epoch {epoch}: Train Loss {train_loss:.4f}, Train Acc {train_acc:.4f}, Val Acc {val_acc:.4f}, Val F1 {val_f1:.4f}")

    class_names = ['Right','Left','Forward','Backward','Up','Down','TurnL','TurnR','Hover']
    val_acc, val_f1, val_cm = evaluate(model, val_loader, device)
    plot_confusion_matrix(val_cm, class_names)

    # Quantize and evaluate
    quantized_model = quantize_model(model)
    # The quantized model lives on the CPU, so its inputs must stay there
    # even when training ran on a GPU.
    val_acc, val_f1, val_cm = evaluate(quantized_model, val_loader, torch.device('cpu'))
    print(f"Quantized Model: Val Acc {val_acc:.4f}, Val F1 {val_f1:.4f}")