In [1]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
print(torch.__version__)
print(torch.version.cuda)


cuda
2.9.1+cu128
12.8


In [2]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class NumpyImageDataset(Dataset):
    """
    Expects:
      X: (N, H, W, 3) uint8
      y: (N,) int labels
    """
    def __init__(self, x_path, y_path, transform=None):
        super().__init__()
        self.images = np.load(x_path)   # (N, H, W, 3)
        self.labels = np.load(y_path)   # (N,)
        self.transform = transform

        if self.images.ndim != 4 or self.images.shape[-1] != 3:
            raise ValueError(f"Expected images of shape (N, H, W, 3), got {self.images.shape}")
        if len(self.images) != len(self.labels):
            raise ValueError("X and y must have the same length")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]          # numpy (H, W, 3), likely uint8
        label = int(self.labels[idx])

        # to torch, CHW, float32 in [0, 1]
        img = torch.from_numpy(img)     # (H, W, 3), dtype uint8 or similar
        img = img.permute(2, 0, 1)      # (3, H, W)
        img = img.float() / 255.0       # (3, H, W), float32

        if self.transform is not None:
            img = self.transform(img)

        return img, label


In [4]:
from torch.utils.data import DataLoader

def get_dataloaders(
    data_dir="dataset",
    batch_size=32,
    num_workers=0,
    shuffle_train=True,
    transform=None,
):
    x_train_path = os.path.join(data_dir, "x_train.npy")
    y_train_path = os.path.join(data_dir, "y_train.npy")
    x_val_path   = os.path.join(data_dir, "x_val.npy")
    y_val_path   = os.path.join(data_dir, "y_val.npy")
    x_test_path  = os.path.join(data_dir, "x_test.npy")
    y_test_path  = os.path.join(data_dir, "y_test.npy")

    train_dataset = NumpyImageDataset(x_train_path, y_train_path, transform=transform)
    val_dataset   = NumpyImageDataset(x_val_path,   y_val_path,   transform=transform)
    test_dataset  = NumpyImageDataset(x_test_path,  y_test_path,  transform=transform)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle_train,
        num_workers=num_workers,
        pin_memory=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_loader, val_loader, test_loader

train_loader, val_loader, test_loader = get_dataloaders(
    data_dir="dataset",
    batch_size=32,
)

In [5]:
for images, labels in train_loader:
    print(images.shape, images.dtype)   # (B, 3, H, W), torch.float32
    print(labels.shape, labels.dtype)   # (B,), torch.int64
    break


torch.Size([32, 3, 150, 150]) torch.float32
torch.Size([32]) torch.int64


## CNN + ViT architecture

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [7]:
#@title refactored Pytorch version of CNN

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class turkey_cnn(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 64, 3)
    self.pool1 = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(64, 96, 3)
    self.pool2 = nn.MaxPool2d(2, 2)
    self.conv3 = nn.Conv2d(96, 96, 3)
    self.pool3 = nn.MaxPool2d(2, 2)
    self.pool4 = nn.MaxPool2d(2,2)
  def forward(self, x):
    x = self.pool1(F.relu(self.conv1(x)))
    x = self.pool2(F.relu(self.conv2(x)))
    x = self.pool3(F.relu(self.conv3(x)))
    x = self.pool4(x)
    return x

# from torchinfo import summary

# Assume 'YourModelClass' is the name of your PyTorch class
cnn_model = turkey_cnn()

# (Batch Size, Channels, Height, Width)
input_shape = (100, 3, 150, 150)

# Print the model summary
# summary(cnn_model, input_size=input_shape)

#@title ViT component using CNN as backbone

# (B, C, H', W')

class ViT_hm(nn.Module):
  def __init__(self, cnn: turkey_cnn, C, D, N, num_heads, num_layers, num_cls):
    super().__init__()
    self.cnn = cnn
    self.D = D
    self.C = C
    self.to_tokens = nn.Linear(in_features=C, out_features=D )

    self.cls_token = nn.Parameter(torch.randn(1, 1, D))
    self.pos_emb = nn.Parameter(torch.randn(1, N+1, D))

    self.encoder_layer = nn.TransformerEncoderLayer(
        d_model = D,
        nhead = num_heads,
        batch_first = True # (B, N, D)
    )
    self.transformer = nn.TransformerEncoder(
        encoder_layer = self.encoder_layer,
        num_layers = num_layers
    )
    self.proj_cls = nn.Linear(D, num_cls)
    self.norm = nn.LayerNorm(D)

  def forward(self, x):
    feat = self.cnn(x)
    B, C, Hp, Wp = feat.shape
    feat = feat.flatten(2) # (B, C, N), N=H'*W'
    feat = feat.transpose(2, 1) # (B, N, C)
    tokens = self.to_tokens(feat) # (B, N, D)
    # cls token for global attending
    cls = self.cls_token.expand(B, -1, -1)
    x = torch.cat([cls, tokens], dim=1)

    #add position embedding
    x = x + self.pos_emb[:, : x.size(1),:]

    # multiple passes of MHSA
    x = self.transformer(x) # (B, N+1, D)
    x = self.norm(x)
    # extract cls and classify
    cls_tok = x[:, 0, :] # (B, D)
    logits = self.proj_cls(cls_tok) # (B, num_cls)
    return logits

### Define graph

In [13]:
cnn_backbone = turkey_cnn()

data_iter = iter(train_loader)
first_batch_dat, first_batch_lab = next(data_iter)
batch_size = first_batch_dat.size(0)
feat = cnn_backbone(first_batch_dat)
_, C, Hp, Wp = feat.shape
N = Hp * Wp
model = ViT_hm(cnn=cnn_backbone, C=C, D=128, N=N,
               num_heads=4, num_layers=2, num_cls=2)

print(f"batch size: {batch_size} * C {C} * Hp {Hp} * Wp {Wp}")
print(f" B * N * D = {batch_size} * {N} * {128}")
print(model)

batch size: 32 * C 96 * Hp 8 * Wp 8
 B * N * D = 32 * 64 * 128
ViT_hm(
  (cnn): turkey_cnn(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1))
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv3): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1))
    (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (to_tokens): Linear(in_features=96, out_features=128, bias=True)
  (encoder_layer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
    )
    (linear1): Linear(in_features=128, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)


### Training

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
model.to(device)

num_epochs = 25

for epoch in range(num_epochs):
    # ----- TRAIN -----
    model.train()
    running_loss_train = 0.0
    accuracy_train = 0.0

    for img, lab in train_loader:
        img = img.to(device)           # (B, 3, H, W), float32
        lab = lab.to(device)           # (B,), long

        optimizer.zero_grad()
        out = model(img)               # (B, 2)
        loss = loss_fn(out, lab)
        loss.backward()
        optimizer.step()

        running_loss_train += loss.item() * img.size(0)
        _, preds = torch.max(out, 1)   # argmax over classes
        accuracy_train += (preds == lab).sum().item()

    epoch_loss_train = running_loss_train / len(train_loader.dataset)
    epoch_acc_train = accuracy_train / len(train_loader.dataset)
    print(
        f"Epoch {epoch+1}/{num_epochs} - "
        f"Train loss: {epoch_loss_train:.4f}, Train acc: {epoch_acc_train:.4f}"
    )

    # ----- VALIDATION -----
    model.eval()
    running_loss_val = 0.0
    accuracy_val = 0.0

    with torch.no_grad():
        for img, lab in val_loader:
            img = img.to(device)
            lab = lab.to(device)

            out = model(img)
            loss = loss_fn(out, lab)

            running_loss_val += loss.item() * img.size(0)
            _, preds = torch.max(out, 1)
            accuracy_val += (preds == lab).sum().item()

    epoch_loss_val = running_loss_val / len(val_loader.dataset)
    epoch_acc_val = accuracy_val / len(val_loader.dataset)
    print(
        f"Epoch {epoch+1}/{num_epochs} - "
        f"Val loss: {epoch_loss_val:.4f}, Val acc: {epoch_acc_val:.4f}"
    )


Epoch 1/25 - Train loss: 0.7241, Train acc: 0.4960
Epoch 1/25 - Val loss: 0.6934, Val acc: 0.4989
Epoch 2/25 - Train loss: 0.7004, Train acc: 0.5028
Epoch 2/25 - Val loss: 0.6904, Val acc: 0.4989
Epoch 3/25 - Train loss: 0.6804, Train acc: 0.5524
Epoch 3/25 - Val loss: 0.6536, Val acc: 0.5225
Epoch 4/25 - Train loss: 0.5688, Train acc: 0.6981
Epoch 4/25 - Val loss: 0.5065, Val acc: 0.7537
Epoch 5/25 - Train loss: 0.5208, Train acc: 0.7520
Epoch 5/25 - Val loss: 0.5185, Val acc: 0.7580
Epoch 6/25 - Train loss: 0.4909, Train acc: 0.7728
Epoch 6/25 - Val loss: 0.4366, Val acc: 0.7923
Epoch 7/25 - Train loss: 0.4403, Train acc: 0.7949
Epoch 7/25 - Val loss: 0.4777, Val acc: 0.7730
Epoch 8/25 - Train loss: 0.4397, Train acc: 0.7906
Epoch 8/25 - Val loss: 0.4185, Val acc: 0.8094
Epoch 9/25 - Train loss: 0.4295, Train acc: 0.8169
Epoch 9/25 - Val loss: 0.4173, Val acc: 0.8094
Epoch 10/25 - Train loss: 0.4289, Train acc: 0.8047
Epoch 10/25 - Val loss: 0.4385, Val acc: 0.8201
Epoch 11/25 - Trai

### evaluation with test set

In [15]:
import torch

test_loss = 0.0
correct = 0
total = 0

TP = 0
FP = 0
FN = 0

model.eval()
with torch.no_grad():
    for img, lab in test_loader:
        img = img.to(device)
        lab = lab.to(device)

        out = model(img)                       # (B, 2)
        loss = loss_fn(out, lab)

        test_loss += loss.item() * img.size(0)

        _, preds = torch.max(out, 1)           # (B,)
        correct += (preds == lab).sum().item()
        total += lab.size(0)

        # Binary metrics, assuming positive class = 1
        TP += ((preds == 1) & (lab == 1)).sum().item()
        FP += ((preds == 1) & (lab == 0)).sum().item()
        FN += ((preds == 0) & (lab == 1)).sum().item()

# Averages
test_loss /= len(test_loader.dataset)
test_accuracy = correct / total

# Precision, recall, F1 for class 1
if TP + FP > 0:
    test_precision = TP / (TP + FP)
else:
    test_precision = 0.0

if TP + FN > 0:
    test_recall = TP / (TP + FN)
else:
    test_recall = 0.0

if test_precision + test_recall > 0:
    test_f1 = 2 * test_precision * test_recall / (test_precision + test_recall)
else:
    test_f1 = 0.0

print(
    f"Test Loss: {test_loss:.4f}, "
    f"Test Accuracy: {test_accuracy:.4f}, "
    f"Test Precision: {test_precision:.4f}, "
    f"Test Recall: {test_recall:.4f}, "
    f"Test F1: {test_f1:.4f}"
)

Test Loss: 0.3543, Test Accuracy: 0.8462, Test Precision: 0.8293, Test Recall: 0.8718, Test F1: 0.8500


### save model

In [16]:
import torch
import torch.nn as nn

def save_scripted_model(model, path="cnn_vit_ship_classifier_jit.pt"):
    model.eval()
    cpu_model = model.to("cpu")
    scripted_model = torch.jit.script(cpu_model)
    scripted_model.save(path)
    print(f"Saved TorchScript model to: {path}")


save_scripted_model(model, path="cnn_vit_ship_classifier_jit.pt")


Saved TorchScript model to: cnn_vit_ship_classifier_jit.pt


## Test on OOD Fire data

In [17]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# ----------------------------
# 1. OOD dataset for 150x150
# ----------------------------

class OODShipsDataset(Dataset):
    """
    Expects a folder:
        dataset/ood_on_fire/
            civilian_....jpg / png / ...
            military_....jpg / png / ...
    Labels:
        civilian_* -> 0
        military_* -> 1
    """
    def __init__(self, root, transform=None):
        super().__init__()
        self.root = root
        self.transform = transform
        self.samples = []

        exts = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"}

        for fname in os.listdir(root):
            fpath = os.path.join(root, fname)
            if not os.path.isfile(fpath):
                continue

            ext = os.path.splitext(fname)[1].lower()
            if ext not in exts:
                continue

            lower = fname.lower()
            if lower.startswith("civilian"):
                label = 0
            elif lower.startswith("military"):
                label = 1
            else:
                # skip files that don't match naming convention
                continue

            self.samples.append((fpath, label))

        if not self.samples:
            raise RuntimeError(f"No valid images found in {root}")

        print(f"Found {len(self.samples)} OOD images in {root}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, label


# ----------------------------
# 2. Transform + DataLoader
# ----------------------------

ood_root = "dataset/ood_on_fire"

transform = transforms.Compose([
    transforms.Resize((150, 150)),  # <--- 150 x 150, as you requested
    transforms.ToTensor(),          # (3, H, W), float32 in [0,1]
])

ood_dataset = OODShipsDataset(ood_root, transform=transform)

ood_loader = DataLoader(
    ood_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)

# ----------------------------
# 3. Inference + metrics
# ----------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

all_labels = []
all_preds = []

with torch.no_grad():
    for imgs, labs in ood_loader:
        imgs = imgs.to(device)          # (B, 3, 150, 150)
        labs = labs.to(device)          # (B,)

        logits = model(imgs)            # (B, 2)
        preds = torch.argmax(logits, dim=1)  # (B,)

        all_labels.extend(labs.cpu().numpy().tolist())
        all_preds.extend(preds.cpu().numpy().tolist())

y_true = np.array(all_labels)
y_pred = np.array(all_preds)

acc = accuracy_score(y_true, y_pred)
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true,
    y_pred,
    average="binary",
    pos_label=1,  # military = 1
)

print(f"OOD on fire evaluation (n={len(y_true)}):")
print(f"Accuracy : {acc:.4f}")
print(f"Precision (military=1): {precision:.4f}")
print(f"Recall    (military=1): {recall:.4f}")
print(f"F1        (military=1): {f1:.4f}")


Found 61 OOD images in dataset/ood_on_fire
OOD on fire evaluation (n=61):
Accuracy : 0.6885
Precision (military=1): 0.7143
Recall    (military=1): 0.6452
F1        (military=1): 0.6780
